---
language:
- en
license: mit
tags:
- table-structure-recognition
- computer-vision
- pytorch
- document-ai
- table-detection
library_name: pytorch
pipeline_tag: image-classification
datasets:
- ds4sd/FinTabNet_OTSL
---

# TABLET Split Model - Table Structure Recognition

This repository contains the **Split Model** implementation from the paper [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1), trained to detect row and column splits in table images.

## Model Description

The Split Model is a deep learning architecture that detects horizontal and vertical splits in table images, enabling accurate table structure recognition. Given a table image, the model predicts the positions of row and column boundaries.

### Architecture

The model consists of three main components (sketched in code after the Training Details list below):

1. **Modified ResNet-18 Backbone**
   - Max pooling layer removed for better spatial resolution
   - Channel dimensions halved for efficiency (32→256 channels)
   - Outputs features at 1/16 resolution (60×60 for a 960×960 input)

2. **Feature Pyramid Network (FPN)**
   - Upsamples backbone features to 1/2 resolution (480×480)
   - Reduces channels to 128

3. **Dual Transformer Branches**
   - **Horizontal Branch**: detects row splits using a 1D transformer
   - **Vertical Branch**: detects column splits using a 1D transformer
   - Each branch combines:
     - Global features: learnable weighted averaging
     - Local features: spatial pooling with a 1×1 convolution
     - Positional embeddings: 1D learned embeddings
   - 3-layer transformer encoder with 8 attention heads

### Training Details

- **Dataset**: combination of FinTabNet and PubTabNet (OTSL format)
- **Input Size**: 960×960 pixels
- **Batch Size**: 32
- **Epochs**: 16
- **Optimizer**: AdamW (lr=3e-4, weight_decay=5e-4)
- **Loss Function**: Focal Loss (α=1.0, γ=2.0), sketched below
- **Ground Truth**: dynamic gap-based split detection from OTSL annotations, sketched below
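To make the dual-branch design above concrete, here is a minimal, self-contained sketch of one 1D transformer branch. It assumes the FPN has already produced a `[B, 128, 480, 480]` feature map; the class and parameter names are illustrative and not the actual code in `split_model.py`.

```python
import torch
import torch.nn as nn

class SplitBranch(nn.Module):
    """One 1D transformer branch (horizontal or vertical), as described above."""

    def __init__(self, in_ch=128, d_model=128, length=480, heads=8, layers=3):
        super().__init__()
        self.local = nn.Conv1d(in_ch, d_model, kernel_size=1)  # local features: 1x1 conv
        self.global_w = nn.Parameter(torch.zeros(length))      # global features: learnable weighted averaging
        self.pos = nn.Parameter(torch.zeros(length, d_model))  # 1D learned positional embeddings
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, fmap, pool_dim):
        # fmap: [B, C, 480, 480]; average over one spatial axis to get a 1D sequence
        x = fmap.mean(dim=pool_dim)                          # [B, C, 480] spatial pooling
        x = self.local(x).transpose(1, 2)                    # [B, 480, d_model]
        w = torch.softmax(self.global_w, dim=0)              # weights for the global average
        g = (x * w.unsqueeze(-1)).sum(dim=1, keepdim=True)   # [B, 1, d_model] global feature
        x = self.encoder(x + g + self.pos)                   # combine local, global, positional
        return torch.sigmoid(self.head(x)).squeeze(-1)       # [B, 480] split probabilities

# Pooling over width (dim 3) yields row splits; over height (dim 2), column splits
fmap = torch.randn(1, 128, 480, 480)       # stand-in for the FPN output
h_prob = SplitBranch()(fmap, pool_dim=3)   # [1, 480] row-split probabilities
v_prob = SplitBranch()(fmap, pool_dim=2)   # [1, 480] column-split probabilities
```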
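The focal loss used in training down-weights the many easy non-split pixels so that the rare split pixels dominate the gradient. Below is a minimal sketch of binary focal loss with the α=1.0, γ=2.0 settings listed above; the exact implementation in `train_split_fixed.py` may differ.

```python
import torch

def binary_focal_loss(probs, targets, alpha=1.0, gamma=2.0, eps=1e-6):
    """Focal loss over per-pixel split probabilities.

    probs:   [B, N] predicted split probabilities (after sigmoid)
    targets: [B, N] binary ground-truth split mask
    """
    probs = probs.clamp(eps, 1.0 - eps)
    # p_t: probability assigned to the true class of each pixel
    p_t = torch.where(targets > 0.5, probs, 1.0 - probs)
    # (1 - p_t)^gamma shrinks the loss on easy, well-classified pixels
    return (-alpha * (1.0 - p_t) ** gamma * torch.log(p_t)).mean()

# During training, the horizontal and vertical branch losses would be combined:
# loss = binary_focal_loss(h_pred, h_gt) + binary_focal_loss(v_pred, v_gt)
```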
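"Dynamic gap-based" ground truth means a split band is as wide as the actual gap between neighboring cells rather than a fixed width. One plausible reading is sketched below with a hypothetical helper (the real construction lives in the dataset classes of `split_model.py`): mark every pixel row that no cell bounding box covers as a split pixel.

```python
import numpy as np

def horizontal_split_mask(cell_boxes, size=480):
    """Mark pixel rows lying in the gaps between cells as splits.

    cell_boxes: iterable of (x0, y0, x1, y1) boxes in mask coordinates.
    Taking the union of cell coverage also handles overlapping cells:
    a row counts as a split only if no cell touches it.
    """
    covered = np.zeros(size, dtype=bool)
    for _, y0, _, y1 in cell_boxes:
        covered[int(y0):int(np.ceil(y1))] = True
    mask = ~covered
    # Do not treat the margins above the first and below the last cell as splits
    ys = np.flatnonzero(covered)
    if ys.size:
        mask[:ys[0]] = False
        mask[ys[-1] + 1:] = False
    return mask.astype(np.float32)
```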
## Installation

```bash
pip install torch torchvision pillow numpy
```

## Usage

### Basic Inference

```python
import torch
from PIL import Image
import torchvision.transforms as transforms

from split_model import SplitModel

# Select device and load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SplitModel().to(device)

# Load checkpoint
checkpoint = torch.load('split_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Prepare the image
transform = transforms.Compose([
    transforms.Resize((960, 960)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
image = Image.open('table_image.png').convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)

# Predict: each output is a [1, 480] vector of split probabilities
with torch.no_grad():
    h_pred, v_pred = model(image_tensor)

# Upsample to 960 for visualization
h_pred = h_pred.repeat_interleave(2, dim=1)  # [1, 960]
v_pred = v_pred.repeat_interleave(2, dim=1)  # [1, 960]

# Threshold to get binary split masks
h_splits = (h_pred > 0.5).float()
v_splits = (v_pred > 0.5).float()

# Each split is a contiguous band of positive pixels, so count
# 0 -> 1 transitions (split regions) rather than positive pixels
def count_splits(mask):
    mask = mask.squeeze(0)
    rises = ((mask[1:] > 0) & (mask[:-1] == 0)).sum().item()
    return int(rises + (mask[0] > 0).item())

num_rows = count_splits(h_splits) + 1
num_cols = count_splits(v_splits) + 1
print(f"Detected {num_rows} rows and {num_cols} columns")
```

### Visualize Predictions

Use the included visualization script to test the model on your own images:

```bash
python test_split_by_images_folder.py \
    --image-folder /path/to/images \
    --output-folder predictions_output \
    --model-path split_model.pth \
    --threshold 0.5
```

## Model Performance

The model was trained on the combined FinTabNet and PubTabNet datasets:

- Training samples: ~250K table images
- Validation F1 scores typically exceed 0.90 for both horizontal and vertical splits
- Robust to varied table styles, merged cells, and complex layouts

## Files in this Repository

- `split_model.py` - Model architecture and dataset classes
- `train_split_fixed.py` - Training script
- `test_split_by_images_folder.py` - Inference and visualization script
- `split_model.pth` - Trained model weights

## Key Features

- **Dynamic Gap Detection**: automatically handles varying gap widths between cells
- **Overlap Handling**: correctly processes tables with overlapping cell boundaries
- **Focal Loss Training**: addresses the class imbalance between split and non-split pixels
- **Transformer-based**: captures long-range dependencies in complex table structures

## Citation

If you use this model, please cite the original TABLET paper:

```bibtex
@article{tablet2025,
  title={TABLET: Learning From Instructions For Tabular Data},
  author={[Authors from paper]},
  journal={arXiv preprint arXiv:2506.07015},
  year={2025}
}
```

## Paper Reference

This implementation is based on the Split Model described in Section 3.2 of [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1).

## License

This model is released under the MIT license for research purposes. Please refer to the original paper for details.

## Acknowledgments

- The original paper authors for the TABLET framework
- The FinTabNet and PubTabNet datasets for training data
- The PyTorch team for the deep learning framework