|
|
--- |
|
|
language: |
|
|
- en |
|
|
license: mit |
|
|
tags: |
|
|
- table-structure-recognition |
|
|
- computer-vision |
|
|
- pytorch |
|
|
- document-ai |
|
|
- table-detection |
|
|
library_name: pytorch |
|
|
pipeline_tag: image-classification |
|
|
datasets: |
|
|
- ds4sd/FinTabNet_OTSL |
|
|
--- |
|
|
|
|
|
# TABLET Split Model - Table Structure Recognition |
|
|
|
|
|
This repository contains the **Split Model** implementation from the paper [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1), trained to detect row and column splits in table images.
|
|
|
|
|
## Model Description |
|
|
|
|
|
The Split Model is a deep learning architecture designed to detect horizontal and vertical splits in table images, enabling accurate table structure recognition. The model processes a table image and predicts the positions of row and column boundaries. |
|
|
|
|
|
### Architecture |
|
|
|
|
|
The model consists of three main components (a structural sketch of one branch follows the list):
|
|
|
|
|
1. **Modified ResNet-18 Backbone** |
|
|
- Removed max pooling layer for better spatial resolution |
|
|
- Channel dimensions halved for efficiency relative to standard ResNet-18 (32→256 instead of 64→512)
|
|
- Outputs features at 1/16 resolution (60×60 for 960×960 input) |
|
|
|
|
|
2. **Feature Pyramid Network (FPN)** |
|
|
- Upsamples backbone features to 1/2 resolution (480×480) |
|
|
- Reduces channels to 128 dimensions |
|
|
|
|
|
3. **Dual Transformer Branches** |
|
|
- **Horizontal Branch**: Detects row splits using 1D transformer |
|
|
- **Vertical Branch**: Detects column splits using 1D transformer |
|
|
- Each branch combines: |
|
|
- Global features: Learnable weighted averaging |
|
|
- Local features: Spatial pooling with 1×1 convolution |
|
|
- Positional embeddings: 1D learned embeddings |
|
|
- 3-layer transformer encoder with 8 attention heads |
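
For orientation, below is a minimal, hypothetical PyTorch sketch of the horizontal branch, assuming the FPN output is a `[B, 128, 480, 480]` feature map. Class and attribute names (`SplitBranchSketch`, `gate`, `global_proj`, `local_proj`) are illustrative only; the reference implementation is in `split_model.py`.

```python
import torch
import torch.nn as nn

class SplitBranchSketch(nn.Module):
    """Illustrative horizontal branch: collapses the width axis of the FPN map
    and predicts a split probability for each of the 480 row positions."""

    def __init__(self, in_ch=128, d_model=128, length=480):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(length))                 # learnable weights for global averaging
        self.global_proj = nn.Conv1d(in_ch, d_model, kernel_size=1)   # projects weighted-average (global) features
        self.local_proj = nn.Conv1d(in_ch, d_model, kernel_size=1)    # 1x1 conv on mean-pooled (local) features
        self.pos = nn.Parameter(torch.zeros(1, length, d_model))      # learned 1D positional embeddings
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=3)     # 3 layers, 8 attention heads
        self.head = nn.Linear(d_model, 1)

    def forward(self, fmap):                                          # fmap: [B, 128, 480, 480]
        w = torch.softmax(self.gate, dim=0)                           # weights over the collapsed (width) axis
        global_feat = self.global_proj((fmap * w).sum(dim=-1))        # [B, d_model, 480]
        local_feat = self.local_proj(fmap.mean(dim=-1))               # [B, d_model, 480]
        x = (global_feat + local_feat).transpose(1, 2) + self.pos     # [B, 480, d_model]
        return torch.sigmoid(self.head(self.encoder(x))).squeeze(-1)  # [B, 480] split probabilities
```

The vertical branch would be analogous, collapsing the height axis instead (for example by transposing the feature map before the same module).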
|
|
|
|
|
### Training Details |
|
|
|
|
|
- **Dataset**: Combination of FinTabNet and PubTabNet (OTSL format) |
|
|
- **Input Size**: 960×960 pixels |
|
|
- **Batch Size**: 32 |
|
|
- **Epochs**: 16 |
|
|
- **Optimizer**: AdamW (lr=3e-4, weight_decay=5e-4) |
|
|
- **Loss Function**: Focal Loss (α=1.0, γ=2.0); a minimal sketch follows this list
|
|
- **Ground Truth**: Dynamic gap-based split detection from OTSL annotations |
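
For reference, a hedged sketch of a binary focal loss with the α and γ values listed above; the actual training script (`train_split_fixed.py`) may differ in weighting and reduction details.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(pred, target, alpha=1.0, gamma=2.0):
    """pred: predicted split probabilities [B, L]; target: binary split mask [B, L]."""
    bce = F.binary_cross_entropy(pred, target, reduction="none")
    p_t = target * pred + (1 - target) * (1 - pred)       # probability assigned to the true class
    return (alpha * (1.0 - p_t) ** gamma * bce).mean()    # down-weights easy, well-classified positions
```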
|
|
|
|
|
## Installation |
|
|
|
|
|
```bash |
|
|
pip install torch torchvision pillow numpy |
|
|
``` |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Basic Inference |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
from split_model import SplitModel |
|
|
|
|
|
# Load model |
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
model = SplitModel().to(device) |
|
|
|
|
|
# Load checkpoint |
|
|
checkpoint = torch.load('split_model.pth', map_location=device) |
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
model.eval() |
|
|
|
|
|
# Prepare image |
|
|
transform = transforms.Compose([ |
|
|
transforms.Resize((960, 960)), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
image = Image.open('table_image.png').convert('RGB') |
|
|
image_tensor = transform(image).unsqueeze(0).to(device) |
|
|
|
|
|
# Predict |
|
|
with torch.no_grad(): |
|
|
h_pred, v_pred = model(image_tensor) # Returns [1, 480] predictions |
|
|
|
|
|
# Upsample to 960 for visualization |
|
|
h_pred = h_pred.repeat_interleave(2, dim=1) # [1, 960] |
|
|
v_pred = v_pred.repeat_interleave(2, dim=1) # [1, 960] |
|
|
|
|
|
# Apply threshold |
|
|
h_splits = (h_pred > 0.5).float() |
|
|
v_splits = (v_pred > 0.5).float() |
|
|
|
|
|
# Count rows and columns: each contiguous run of split pixels is one separator line
h_starts = (torch.diff(h_splits, dim=1, prepend=torch.zeros(1, 1, device=device)) == 1).sum()
v_starts = (torch.diff(v_splits, dim=1, prepend=torch.zeros(1, 1, device=device)) == 1).sum()
num_rows = int(h_starts) + 1
num_cols = int(v_starts) + 1
|
|
|
|
|
print(f"Detected {num_rows} rows and {num_cols} columns") |
|
|
``` |
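
If you need the split positions in the coordinates of the original image rather than the 960×960 working resolution, a simple mapping back through the resize looks like this (it assumes only the plain `Resize((960, 960))` used above; `row_lines`/`col_lines` are illustrative names):

```python
# Map split-pixel indices from 960-space back to the original image size
orig_w, orig_h = image.size
h_idx = torch.nonzero(h_splits[0]).flatten()    # row-split positions in 960-space
v_idx = torch.nonzero(v_splits[0]).flatten()    # column-split positions in 960-space
row_lines = (h_idx.float() * orig_h / 960).round().tolist()
col_lines = (v_idx.float() * orig_w / 960).round().tolist()
```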
|
|
|
|
|
### Visualize Predictions |
|
|
|
|
|
Use the included visualization script to test on your images: |
|
|
|
|
|
```bash |
|
|
python test_split_by_images_folder.py \ |
|
|
--image-folder /path/to/images \ |
|
|
--output-folder predictions_output \ |
|
|
--model-path split_model.pth \ |
|
|
--threshold 0.5 |
|
|
``` |
|
|
|
|
|
## Model Performance |
|
|
|
|
|
The model was trained on combined FinTabNet and PubTabNet datasets: |
|
|
- Training samples: ~250K table images |
|
|
- Validation F1 scores typically exceed 0.90 for both horizontal and vertical splits
|
|
- Robust to various table styles, merged cells, and complex layouts |
|
|
|
|
|
## Files in this Repository |
|
|
|
|
|
- `split_model.py` - Model architecture and dataset classes |
|
|
- `train_split_fixed.py` - Training script |
|
|
- `test_split_by_images_folder.py` - Inference and visualization script |
|
|
- `split_model.pth` - Trained model weights |
|
|
|
|
|
## Key Features |
|
|
|
|
|
- **Dynamic Gap Detection**: Automatically handles varying gap widths between cells (see the ground-truth sketch after this list)
|
|
- **Overlap Handling**: Correctly processes tables with overlapping cell boundaries |
|
|
- **Focal Loss Training**: Addresses class imbalance between split and non-split pixels |
|
|
- **Transformer-based**: Captures long-range dependencies for complex table structures |
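
To make the gap-based ground-truth idea concrete, here is a simplified sketch assuming cell bounding boxes in resized-image coordinates; the function name and the binary-mask simplification are illustrative and not the exact pipeline in `train_split_fixed.py`.

```python
import numpy as np

def row_split_mask(cell_boxes, height=960):
    """Mark pixel rows not covered by any cell's vertical extent as row-split positions."""
    covered = np.zeros(height, dtype=bool)
    for _x0, y0, _x1, y1 in cell_boxes:
        covered[int(y0):int(y1)] = True      # rows occupied by this cell
    return (~covered).astype(np.float32)     # 1.0 in the gaps between rows of cells
```

The column mask is built the same way from the horizontal cell extents.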
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use this model, please cite the original TABLET paper: |
|
|
|
|
|
```bibtex |
|
|
@article{tablet2025, |
|
|
title={TABLET: Learning From Instructions For Tabular Data}, |
|
|
author={[Authors from paper]}, |
|
|
journal={arXiv preprint arXiv:2506.07015}, |
|
|
year={2025} |
|
|
} |
|
|
``` |
|
|
|
|
|
## Paper Reference |
|
|
|
|
|
This implementation is based on the Split Model described in Section 3.2 of: |
|
|
[TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1) |
|
|
|
|
|
## License |
|
|
|
|
|
This model is released under the MIT license for research purposes. Please refer to the original paper for more details.
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
- Original paper authors for the TABLET framework |
|
|
- FinTabNet and PubTabNet datasets for training data |
|
|
- PyTorch team for the deep learning framework |
|
|
|