---
language: en
license: apache-2.0
library_name: transformers
tags:
- multi-token-prediction
- gpt2
- mathematics
- MetaMathQA
- transformer
- speculative-decoding
pipeline_tag: text-generation
base_model: gpt2
---
# Multi-Token Prediction GPT-2 Model
This is a GPT-2 model extended with a **Multi-Token Prediction (MTP)** architecture and fine-tuned on the MetaMathQA dataset for mathematical reasoning tasks.
## Model Description
This model implements the Multi-Token Prediction approach from the paper "Better & Faster Large Language Models via Multi-token Prediction" by Meta AI. Key features:
- **Shared Trunk Architecture**: Uses 9 shared transformer layers with 4 prediction heads
- **Multi-Token Prediction**: Predicts 4 tokens simultaneously (t+1, t+2, t+3, t+4)
- **Speculative Decoding**: The extra prediction heads can draft future tokens, achieving up to 3x inference speedup
- **Mathematical Reasoning**: Fine-tuned specifically for mathematical problem solving
## Architecture Details
- **Base Model**: GPT-2 (124M parameters)
- **Trunk Layers**: 9 (shared processing)
- **Prediction Heads**: 4 parallel heads
- **Training Data**: MetaMathQA dataset (500 samples)
- **Training Epochs**: 1 with gradient accumulation
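
The custom modeling code is not bundled with this repository. As a rough illustration of the shared-trunk design described above, the sketch below builds 9 shared GPT-2 blocks plus 4 prediction heads on top of the Hugging Face `GPT2LMHeadModel`. The class name, the reuse of the final GPT-2 block for each head, and the loss formulation are illustrative assumptions, not the exact training code.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel

class MultiTokenGPT2Sketch(nn.Module):
    """Illustrative shared-trunk MTP model: 9 shared GPT-2 blocks, 4 heads.

    Head k predicts token t+k. Each head here reuses a copy of the last
    GPT-2 block plus the shared unembedding; this mirrors the design
    described above but is a sketch, not the original training code.
    """

    def __init__(self, n_trunk_layers: int = 9, n_future_tokens: int = 4):
        super().__init__()
        base = GPT2LMHeadModel.from_pretrained("gpt2")
        self.wte, self.wpe = base.transformer.wte, base.transformer.wpe
        self.trunk = base.transformer.h[:n_trunk_layers]           # shared trunk layers
        last_block = base.transformer.h[-1]
        self.heads = nn.ModuleList(                                 # one block per future offset
            copy.deepcopy(last_block) for _ in range(n_future_tokens)
        )
        self.ln_f = base.transformer.ln_f
        self.lm_head = base.lm_head                                 # unembedding shared by all heads

    def forward(self, input_ids, labels=None):
        pos = torch.arange(input_ids.size(1), device=input_ids.device)
        h = self.wte(input_ids) + self.wpe(pos)
        for block in self.trunk:
            h = block(h)[0]                                         # GPT2Block returns a tuple

        all_logits, loss = [], None
        for k, head in enumerate(self.heads, start=1):
            logits = self.lm_head(self.ln_f(head(h)[0]))            # head k scores token t+k
            all_logits.append(logits)
            if labels is not None:
                # MTP objective: cross-entropy of each head against labels shifted by k
                ce = F.cross_entropy(
                    logits[:, :-k, :].reshape(-1, logits.size(-1)),
                    labels[:, k:].reshape(-1),
                )
                loss = ce if loss is None else loss + ce
        return all_logits, loss
```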
## Usage
```python
import torch
from transformers import GPT2Tokenizer

# Note: the custom MultiTokenGPT2 class from the training code must be
# defined or imported before this snippet will run.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the model (custom loading via the MultiTokenGPT2 class is required)
model = MultiTokenGPT2.from_pretrained("Goldenwert/multitoken-gpt2-metamathqa")
model.eval()

prompt = "What is the derivative of f(x) = x^3?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Standard (next-token) generation
generated = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=False,
)

# Faster self-speculative generation using the extra prediction heads
generated_fast = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=True,
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(tokenizer.decode(generated_fast[0], skip_special_tokens=True))
```
## Performance
- **Inference Speed**: Up to 3x faster with speculative decoding
- **Memory Efficiency**: Gradient checkpointing support
- **Mathematical Tasks**: Improved reasoning on math problems
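
To check the speculative-decoding speedup on your own hardware, a simple wall-clock comparison can be run with the same `generate` flags shown in the Usage section. This assumes `model` and `input_ids` are already set up as above.

```python
import time

def time_generation(model, input_ids, use_speculative, n_runs=5):
    """Average wall-clock time for 50 new tokens, with or without speculative decoding."""
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        model.generate(input_ids, max_new_tokens=50, use_speculative=use_speculative)
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)

baseline = time_generation(model, input_ids, use_speculative=False)
speculative = time_generation(model, input_ids, use_speculative=True)
print(f"standard: {baseline:.2f}s  speculative: {speculative:.2f}s  "
      f"speedup: {baseline / speculative:.2f}x")
```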
## Training Details
- **Dataset**: MetaMathQA (mathematical reasoning)
- **Optimizer**: AdamW with warmup
- **Learning Rate**: 5e-5
- **Batch Size**: 4 (effective 32 with gradient accumulation)
- **Hardware**: GPU with FP16 precision
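
The training script itself is not part of this repository. The sketch below shows how the listed hyperparameters fit together (AdamW at 5e-5 with linear warmup, per-device batch size 4 accumulated over 8 steps for an effective 32, FP16 via `torch.cuda.amp`); `train_loader` and the warmup step count are illustrative placeholders.

```python
import torch
from transformers import get_linear_schedule_with_warmup

accum_steps = 8                                   # 4 x 8 = effective batch size of 32
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,                         # illustrative warmup length
    num_training_steps=len(train_loader) // accum_steps,
)
scaler = torch.cuda.amp.GradScaler()              # FP16 mixed precision

model.train()
for step, batch in enumerate(train_loader):
    input_ids = batch["input_ids"].cuda()
    with torch.cuda.amp.autocast():
        _, loss = model(input_ids, labels=input_ids)
        loss = loss / accum_steps                 # scale for gradient accumulation
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
```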
## Research Context
Based on the paper "Better & Faster Large Language Models via Multi-token Prediction" which demonstrates that predicting multiple tokens simultaneously can improve both training efficiency and inference speed.
## Files
- `pytorch_model.bin`: Model weights
- `config.json`: Model configuration
- Additional training artifacts
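
Since the checkpoint is stored as a plain `pytorch_model.bin`, it can also be loaded manually with `torch.load` into a compatible class definition. The snippet below is a sketch assuming the illustrative class from the Architecture Details section, so `strict=False` is used in case parameter names differ from the real training code.

```python
import torch
from huggingface_hub import hf_hub_download

# Download the raw checkpoint file from the Hub
weights_path = hf_hub_download(
    repo_id="Goldenwert/multitoken-gpt2-metamathqa",
    filename="pytorch_model.bin",
)
state_dict = torch.load(weights_path, map_location="cpu")

# Load into the (illustrative) class defined above; strict=False tolerates
# any mismatch between the sketch's parameter names and the checkpoint's.
model = MultiTokenGPT2Sketch()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing: {len(missing)}  unexpected: {len(unexpected)}")
```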
## License
This model is released under the Apache 2.0 license.