---
language: en
license: apache-2.0
library_name: transformers
tags:
- multi-token-prediction
- gpt2
- mathematics
- MetaMathQA
- transformer
- speculative-decoding
pipeline_tag: text-generation
base_model: gpt2
---

# Multi-Token Prediction GPT-2 Model

This is a GPT-2 model enhanced with a **Multi-Token Prediction (MTP)** architecture and trained on the MetaMathQA dataset for mathematical reasoning tasks.

## Model Description

This model implements the multi-token prediction approach from the Meta AI paper "Better & Faster Large Language Models via Multi-token Prediction". Key features:

- **Shared Trunk Architecture**: Uses 9 shared transformer layers feeding 4 prediction heads (see the sketch after this list)
- **Multi-Token Prediction**: Predicts 4 future tokens simultaneously (t+1, t+2, t+3, t+4)
- **Enhanced Speculative Decoding**: Achieves up to 3x inference speedup
- **Mathematical Reasoning**: Fine-tuned specifically for mathematical problem solving
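
The `MultiTokenGPT2` class lives in the training code rather than in `transformers`, so the snippet below is a minimal, illustrative sketch of the shared-trunk design: the first 9 GPT-2 blocks form a common trunk, and each of the 4 heads maps the final hidden states to logits for a different future offset. The class name, the use of plain linear heads (the paper uses transformer-layer heads), and the exact layer split are assumptions, not the released implementation.

```python
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

class MultiTokenGPT2Sketch(nn.Module):
    """Illustrative only: a shared GPT-2 trunk followed by independent output
    heads, where head k predicts the token k+1 positions ahead."""

    def __init__(self, n_heads: int = 4, n_trunk_layers: int = 9):
        super().__init__()
        gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.wte = gpt2.transformer.wte                     # token embeddings
        self.wpe = gpt2.transformer.wpe                     # position embeddings
        self.trunk = gpt2.transformer.h[:n_trunk_layers]    # 9 shared blocks
        self.ln_f = gpt2.transformer.ln_f
        # One lightweight head per future offset (linear heads keep the sketch short).
        self.heads = nn.ModuleList(
            [nn.Linear(gpt2.config.n_embd, gpt2.config.vocab_size, bias=False)
             for _ in range(n_heads)]
        )

    def forward(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        hidden = self.wte(input_ids) + self.wpe(positions)
        for block in self.trunk:
            hidden = block(hidden)[0]                       # GPT2Block returns a tuple
        hidden = self.ln_f(hidden)
        # Head k returns logits for token t+k+1 at every position t.
        return [head(hidden) for head in self.heads]

# Smoke test: 4 logits tensors, each of shape (1, 8, 50257).
ids = torch.randint(0, 50257, (1, 8))
print([logits.shape for logits in MultiTokenGPT2Sketch()(ids)])
```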

## Architecture Details

- **Base Model**: GPT-2 (124M parameters)
- **Trunk Layers**: 9 shared transformer layers
- **Prediction Heads**: 4 parallel heads
- **Training Data**: MetaMathQA (500-sample subset)
- **Training Epochs**: 1, with gradient accumulation

## Usage

```python
import torch
from transformers import GPT2Tokenizer

# Note: the custom MultiTokenGPT2 class from the training code is required;
# it is not part of the transformers library.

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the model (custom loading required; one possible approach is sketched below)
# model = MultiTokenGPT2.from_pretrained("Goldenwert/multitoken-gpt2-metamathqa")

prompt = "What is the derivative of f(x) = x^3?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Standard autoregressive generation (next-token head only)
generated = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=False,
)

# Faster speculative generation using the extra prediction heads
generated_fast = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=True,
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(tokenizer.decode(generated_fast[0], skip_special_tokens=True))
```
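
If `MultiTokenGPT2.from_pretrained` is not available in your copy of the training code, the weights can still be pulled from the Hub and loaded manually. The snippet below is a sketch under two assumptions: that `pytorch_model.bin` holds a plain PyTorch state dict, and that `MultiTokenGPT2()` can be constructed with the same configuration used for training.

```python
import torch
from huggingface_hub import hf_hub_download

# Hypothetical loading path: assumes pytorch_model.bin stores a state dict
# compatible with the MultiTokenGPT2 class from the training code.
weights_path = hf_hub_download(
    repo_id="Goldenwert/multitoken-gpt2-metamathqa",
    filename="pytorch_model.bin",
)
state_dict = torch.load(weights_path, map_location="cpu")

model = MultiTokenGPT2()      # construct with the training-time configuration
model.load_state_dict(state_dict)
model.eval()
```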

## Performance

- **Inference Speed**: Up to 3x faster with speculative decoding (see the sketch after this list)
- **Memory Efficiency**: Supports gradient checkpointing
- **Mathematical Tasks**: Improved reasoning on mathematical problems
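
The speedup comes from self-speculative decoding: a single forward pass drafts the next 4 tokens (one per head), and a verification pass with the next-token head keeps the longest drafted prefix it agrees with. The function below is a greedy, batch-simplified sketch, assuming a model whose forward returns one logits tensor per head (head 0 being the next-token head) as in the architecture sketch above; the released model's `generate(use_speculative=True)` may differ in its details.

```python
import torch

@torch.no_grad()
def self_speculative_step(model, input_ids):
    """One greedy draft-and-verify step (illustrative). Assumes model(input_ids)
    returns a list with one (batch, seq_len, vocab) logits tensor per head."""
    logits_per_head = model(input_ids)
    # Draft: one forward pass proposes the next 4 tokens from heads 0..3.
    draft = torch.stack(
        [logits[:, -1].argmax(dim=-1) for logits in logits_per_head], dim=1
    )                                                       # (batch, 4)

    # Verify: score the drafted continuation with the next-token head and keep
    # the longest prefix it agrees with.
    candidate = torch.cat([input_ids, draft], dim=1)
    next_token_logits = model(candidate)[0]
    n_draft = draft.size(1)
    checks = next_token_logits[:, -n_draft - 1:-1].argmax(dim=-1)

    accepted = 0
    for k in range(n_draft):
        if bool((checks[:, k] == draft[:, k]).all()):       # batch-level acceptance
            accepted += 1
        else:
            break

    if accepted == n_draft:
        return candidate                                    # every drafted token verified
    # Keep the verified prefix plus the corrected token at the first mismatch.
    return torch.cat([input_ids, draft[:, :accepted], checks[:, accepted:accepted + 1]], dim=1)
```

In a full implementation the verification pass also produces the next draft, so each forward pass can yield several accepted tokens, which is where the reported speedup comes from.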

## Training Details

- **Dataset**: MetaMathQA (mathematical reasoning)
- **Optimizer**: AdamW with learning-rate warmup
- **Learning Rate**: 5e-5
- **Batch Size**: 4 (effective 32 with gradient accumulation; see the loss sketch after this list)
- **Hardware**: GPU with FP16 precision
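
The training objective behind these numbers is the multi-token prediction loss: head k is scored with cross-entropy against the tokens k+1 positions ahead, and the per-head losses are combined. The sketch below averages the heads and masks padding with `ignore_index`; whether the released training code sums or averages them is an assumption.

```python
import torch
import torch.nn.functional as F

def mtp_loss(logits_per_head, input_ids, pad_token_id=50256):
    """Multi-token prediction objective (sketch): head k is scored against the
    tokens k+1 positions ahead; per-head losses are averaged."""
    losses = []
    for k, logits in enumerate(logits_per_head):
        shift = k + 1
        pred = logits[:, :-shift, :]        # positions that still have a target
        target = input_ids[:, shift:]       # the token k+1 steps ahead
        losses.append(F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            target.reshape(-1),
            ignore_index=pad_token_id,
        ))
    return torch.stack(losses).mean()

# Shape check with random tensors: batch 4, sequence length 16, GPT-2 vocab.
vocab, batch, seq = 50257, 4, 16
fake_logits = [torch.randn(batch, seq, vocab) for _ in range(4)]
fake_ids = torch.randint(0, vocab, (batch, seq))
print(mtp_loss(fake_logits, fake_ids))
```

The effective batch size follows from gradient accumulation: with a micro-batch of 4, gradients are accumulated over 8 micro-batches before each optimizer step, giving 4 × 8 = 32.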

## Research Context

This model is based on the paper "Better & Faster Large Language Models via Multi-token Prediction" (Meta AI), which demonstrates that predicting multiple tokens simultaneously can improve both training efficiency and inference speed.

## Files

- `pytorch_model.bin`: model weights
- `config.json`: model configuration
- Additional training artifacts

## License

Apache 2.0 (following the GPT-2 base model).