---
language:
- en
license: mit
tags:
- table-structure-recognition
- computer-vision
- pytorch
- document-ai
- table-detection
library_name: pytorch
pipeline_tag: image-classification
datasets:
- ds4sd/FinTabNet_OTSL
---
# TABLET Split Model - Table Structure Recognition
This repository contains the **Split Model** implementation from the paper [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1), trained to detect row and column splits in table images.
## Model Description
The Split Model is a deep learning architecture designed to detect horizontal and vertical splits in table images, enabling accurate table structure recognition. The model processes a table image and predicts the positions of row and column boundaries.
### Architecture
The model consists of three main components (a minimal structural sketch follows the list):
1. **Modified ResNet-18 Backbone**
- Removed max pooling layer for better spatial resolution
   - Channel dimensions halved for efficiency (32→256 channels, versus the standard 64→512)
- Outputs features at 1/16 resolution (60×60 for 960×960 input)
2. **Feature Pyramid Network (FPN)**
- Upsamples backbone features to 1/2 resolution (480×480)
- Reduces channels to 128 dimensions
3. **Dual Transformer Branches**
- **Horizontal Branch**: Detects row splits using 1D transformer
- **Vertical Branch**: Detects column splits using 1D transformer
- Each branch combines:
- Global features: Learnable weighted averaging
- Local features: Spatial pooling with 1×1 convolution
- Positional embeddings: 1D learned embeddings
- 3-layer transformer encoder with 8 attention heads
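For orientation, here is a minimal structural sketch of these components in PyTorch. The class name, the stock torchvision ResNet-18 (standard channel widths), and simple mean pooling in place of the learned global/local feature combination are illustrative assumptions; the actual implementation is in `split_model.py`.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class SplitModelSketch(nn.Module):
    """Structural sketch only; the real model lives in split_model.py."""
    def __init__(self, d_model=128, n_heads=8, n_layers=3, seq_len=480):
        super().__init__()
        # 1. ResNet-18 backbone with the max pooling layer removed. (The paper's
        #    version also halves channel widths; stock widths are kept here.)
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu,           # no maxpool
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
        )                                                    # 960 -> 60 (1/16)
        # 2. FPN stand-in: 1x1 channel reduction to 128, upsampled to 1/2 res.
        self.reduce = nn.Conv2d(512, d_model, kernel_size=1)
        # 3. Dual 1D transformer branches with learned positional embeddings.
        def branch():
            layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            return nn.TransformerEncoder(layer, n_layers)
        self.h_branch, self.v_branch = branch(), branch()
        self.h_pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.v_pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.h_head = nn.Linear(d_model, 1)
        self.v_head = nn.Linear(d_model, 1)

    def forward(self, x):                         # x: [B, 3, 960, 960]
        f = self.reduce(self.backbone(x))         # [B, 128, 60, 60]
        f = F.interpolate(f, size=(480, 480), mode='bilinear', align_corners=False)
        # Mean pooling stands in for the learned global/local feature combination.
        h = f.mean(dim=3).transpose(1, 2) + self.h_pos   # pool over width  -> rows
        v = f.mean(dim=2).transpose(1, 2) + self.v_pos   # pool over height -> cols
        h = torch.sigmoid(self.h_head(self.h_branch(h))).squeeze(-1)  # [B, 480]
        v = torch.sigmoid(self.v_head(self.v_branch(v))).squeeze(-1)  # [B, 480]
        return h, v
```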
### Training Details
- **Dataset**: Combination of FinTabNet and PubTabNet (OTSL format)
- **Input Size**: 960×960 pixels
- **Batch Size**: 32
- **Epochs**: 16
- **Optimizer**: AdamW (lr=3e-4, weight_decay=5e-4)
- **Loss Function**: Focal Loss (α=1.0, γ=2.0); a minimal sketch follows this list
- **Ground Truth**: Dynamic gap-based split detection from OTSL annotations
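As a reference for the loss above, here is a minimal sketch of binary focal loss with the stated hyperparameters, assuming the model emits per-position sigmoid probabilities (the actual training loss lives in `train_split_fixed.py`):
```python
import torch.nn.functional as F

def binary_focal_loss(pred, target, alpha=1.0, gamma=2.0):
    """Binary focal loss sketch (assumed form, not the repo's exact code).

    pred:   [B, N] sigmoid probabilities for each split position
    target: [B, N] binary ground-truth split mask
    """
    # p_t is the probability the model assigns to the true class.
    p_t = pred * target + (1 - pred) * (1 - target)
    ce = F.binary_cross_entropy(pred, target, reduction='none')
    # (1 - p_t)^gamma down-weights easy positions; alpha scales the loss.
    return (alpha * (1 - p_t).pow(gamma) * ce).mean()
```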
## Installation
```bash
pip install torch torchvision pillow numpy
```
## Usage
### Basic Inference
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from split_model import SplitModel

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SplitModel().to(device)

# Load checkpoint
checkpoint = torch.load('split_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Prepare image (ImageNet normalization, resized to the 960x960 training size)
transform = transforms.Compose([
    transforms.Resize((960, 960)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open('table_image.png').convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)

# Predict
with torch.no_grad():
    h_pred, v_pred = model(image_tensor)  # Returns [1, 480] predictions

# Upsample to 960 for visualization
h_pred = h_pred.repeat_interleave(2, dim=1)  # [1, 960]
v_pred = v_pred.repeat_interleave(2, dim=1)  # [1, 960]

# Apply threshold
h_splits = (h_pred > 0.5).float()
v_splits = (v_pred > 0.5).float()

# Count rows and columns. A predicted split usually spans several contiguous
# pixels, so count split regions (rising edges) rather than split pixels.
def count_regions(mask):
    edges = torch.diff(mask, prepend=torch.zeros_like(mask[:, :1]))
    return int((edges > 0).sum().item())

num_rows = count_regions(h_splits) + 1
num_cols = count_regions(v_splits) + 1
print(f"Detected {num_rows} rows and {num_cols} columns")
```
### Visualize Predictions
Use the included visualization script to test on your images:
```bash
python test_split_by_images_folder.py \
--image-folder /path/to/images \
--output-folder predictions_output \
--model-path split_model.pth \
--threshold 0.5
```
## Model Performance
The model was trained on combined FinTabNet and PubTabNet datasets:
- Training samples: ~250K table images
- Validation F1 scores typically exceed 0.90 for both horizontal and vertical splits
- Robust to various table styles, merged cells, and complex layouts
## Files in this Repository
- `split_model.py` - Model architecture and dataset classes
- `train_split_fixed.py` - Training script
- `test_split_by_images_folder.py` - Inference and visualization script
- `split_model.pth` - Trained model weights
## Key Features
- **Dynamic Gap Detection**: Automatically handles varying gap widths between cells (an illustrative sketch follows this list)
- **Overlap Handling**: Correctly processes tables with overlapping cell boundaries
- **Focal Loss Training**: Addresses class imbalance between split and non-split pixels
- **Transformer-based**: Captures long-range dependencies for complex table structures
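To make the gap-based ground truth concrete, here is an illustration of one plausible scheme, assuming per-cell pixel bounding boxes are available (`gap_split_mask` is a hypothetical helper, not the repo's exact code): a position is marked as a split wherever no cell covers it, so gap widths adapt automatically.
```python
import numpy as np

def gap_split_mask(cell_boxes, size=960):
    """Hypothetical gap-based ground truth: splits are uncovered rows/columns.

    cell_boxes: iterable of (x0, y0, x1, y1) cell bounding boxes in pixels.
    Returns (horizontal_mask, vertical_mask), each of length `size`.
    """
    h_cover = np.zeros(size, dtype=bool)  # image rows covered by some cell
    v_cover = np.zeros(size, dtype=bool)  # image columns covered by some cell
    for x0, y0, x1, y1 in cell_boxes:
        h_cover[int(y0):int(y1)] = True
        v_cover[int(x0):int(x1)] = True
    # Every uncovered position lies in a gap between rows/columns -> split.
    return (~h_cover).astype(np.float32), (~v_cover).astype(np.float32)
```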
## Citation
If you use this model, please cite the original TABLET paper:
```bibtex
@article{tablet2025,
  title={TABLET: Learning From Instructions For Tabular Data},
  author={[Authors from paper]},
  journal={arXiv preprint arXiv:2506.07015},
  year={2025}
}
```
## Paper Reference
This implementation is based on the Split Model described in Section 3.2 of:
[TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1)
## License
This model is released under the MIT license for research purposes. Please refer to the original paper for more details.
## Acknowledgments
- Original paper authors for the TABLET framework
- FinTabNet and PubTabNet datasets for training data
- PyTorch team for the deep learning framework