# zen-vl-training / app.py
# Author: Hanzo Dev
# Commit 62858a8: Add auto-start training on Space launch
"""
Zen VL Training Space - HuggingFace Pro GPU Training
Trains zen-vl-4b with combined ADP+xLAM datasets
"""
import os
import sys
import time
import json
import random
import logging
from pathlib import Path
from typing import List, Dict, Any
import torch
from transformers import (
Qwen3VLForConditionalGeneration,
Qwen3VLProcessor,
TrainingArguments,
Trainer,
)
from datasets import load_dataset, Dataset
import gradio as gr
# Configure process-wide logging once at import time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Module-level snapshot of the current training run; polled by the Gradio UI.
training_state = {
    "status": "idle",
    "progress": 0,
    "current_step": 0,
    "total_steps": 0,
    "loss": 0.0,
    "logs": [],
}

def log_message(message: str):
    """Record *message* in both the UI log buffer and the Python logger.

    Returns the timestamped entry that was appended to
    ``training_state["logs"]``.
    """
    stamped = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}"
    training_state["logs"].append(stamped)
    logger.info(message)
    return stamped
class ZenVLTrainer:
    """Fine-tunes a zen-vl model on a weighted ADP + xLAM agent-data mixture.

    The trainer pulls the base model from the HuggingFace Hub, builds a
    mixed dataset, runs the HF ``Trainer``, and pushes the result back to
    the Hub.  Progress is mirrored into the module-level ``training_state``
    dict via ``log_message`` so the Gradio UI can poll it.
    """

    def __init__(self, model_size="4b", gpu_type="a10g"):
        """Select model repos and a per-GPU training config.

        Args:
            model_size: zen-vl size tag, e.g. "4b" / "8b" / "30b".
            gpu_type: Space hardware key; unknown values fall back to the
                most conservative ("a10g") configuration.
        """
        self.model_size = model_size
        self.gpu_type = gpu_type
        self.model_name = f"zenlm/zen-vl-{model_size}-instruct"
        self.output_name = f"zenlm/zen-vl-{model_size}-agent"
        # GPU-specific configs: batch size and sample budget tuned per card.
        self.configs = {
            "a10g": {
                "batch_size": 1,
                "gradient_accumulation": 8,
                "max_samples": 30000,
                "learning_rate": 2e-5,
            },
            "a100-large": {
                "batch_size": 2,
                "gradient_accumulation": 8,
                "max_samples": 50000,
                "learning_rate": 2e-5,
            },
            "a100": {
                "batch_size": 4,
                "gradient_accumulation": 8,
                "max_samples": 100000,
                "learning_rate": 2e-5,
            },
        }
        self.config = self.configs.get(gpu_type, self.configs["a10g"])
        log_message(f"Initialized Zen VL Trainer for {model_size} on {gpu_type}")
        log_message(f"Config: {self.config}")

    def load_adp_data(self, max_samples: int = None) -> List[Dict[str, Any]]:
        """Load Agent Data Protocol samples, capped at ``max_samples``.

        Prefers a local cache under data/adp/*.jsonl; otherwise streams the
        configured subsets from the HuggingFace Hub.  Subsets that fail to
        download are skipped with a warning instead of aborting the run.
        """
        log_message("Loading ADP dataset...")
        data_dir = Path("data/adp")
        all_data = []
        if data_dir.exists():
            # Load from local cache.
            for json_file in data_dir.glob("*.jsonl"):
                log_message(f"Loading {json_file.name}...")
                with open(json_file, 'r') as f:
                    for line in f:
                        if line.strip():
                            all_data.append(json.loads(line))
                            if max_samples and len(all_data) >= max_samples:
                                break
                # BUG FIX: the inner break only exits the current file; stop
                # iterating further files once the sample budget is reached.
                if max_samples and len(all_data) >= max_samples:
                    break
        else:
            # Download from HuggingFace.
            log_message("Downloading ADP dataset from HuggingFace...")
            configs = [
                'agenttuning_os', 'agenttuning_kg', 'agenttuning_db',
                'synatra', 'code_feedback', 'go-browse-wa'
            ]
            for config in configs:
                try:
                    dataset = load_dataset(
                        "neulab/agent-data-collection",
                        config,
                        split="train",
                        streaming=True
                    )
                    for example in dataset:
                        all_data.append(example)
                        if max_samples and len(all_data) >= max_samples:
                            break
                    log_message(f"Loaded {len(all_data)} samples from {config}")
                    if max_samples and len(all_data) >= max_samples:
                        break
                except Exception as e:
                    # Best-effort: a missing/broken subset should not kill
                    # the whole training run.
                    log_message(f"Warning: Could not load {config}: {e}")
                    continue
        log_message(f"Loaded {len(all_data)} ADP samples")
        return all_data

    def load_xlam_data(self, max_samples: int = None) -> List[Dict[str, Any]]:
        """Load xLAM function-calling samples, capped at ``max_samples``.

        Prefers the local data/xlam/xlam_converted.jsonl cache; otherwise
        downloads the Salesforce dataset from the Hub (errors are logged,
        and an empty list is returned on failure).
        """
        log_message("Loading xLAM dataset...")
        data_dir = Path("data/xlam")
        all_data = []
        if data_dir.exists():
            # Load from local cache.
            json_file = data_dir / "xlam_converted.jsonl"
            if json_file.exists():
                with open(json_file, 'r') as f:
                    for line in f:
                        if line.strip():
                            all_data.append(json.loads(line))
                            if max_samples and len(all_data) >= max_samples:
                                break
        else:
            # Download from HuggingFace.
            log_message("Downloading xLAM dataset from HuggingFace...")
            try:
                dataset = load_dataset(
                    "Salesforce/xlam-function-calling-60k",
                    split="train"
                )
                for example in dataset:
                    all_data.append(example)
                    if max_samples and len(all_data) >= max_samples:
                        break
                log_message(f"Loaded {len(all_data)} xLAM samples")
            except Exception as e:
                log_message(f"Error loading xLAM: {e}")
        return all_data

    def create_balanced_mixture(
        self,
        adp_data: List[Dict],
        xlam_data: List[Dict],
        adp_weight: float = 0.80,
        xlam_weight: float = 0.20
    ) -> List[Dict]:
        """Create a shuffled mixture of ADP and xLAM data.

        The total size is limited so the requested weights are achievable
        given both pools.  BUG FIX: if the xLAM pool is empty (e.g. its
        download failed) the old formula collapsed the total to zero and
        discarded ALL data; now we fall back to whatever is available.
        """
        log_message(f"Creating balanced mixture: {adp_weight:.0%} ADP, {xlam_weight:.0%} xLAM")
        if xlam_data and xlam_weight > 0:
            total_size = min(len(adp_data), int(len(xlam_data) / xlam_weight))
        else:
            # No xLAM data available: train on the ADP pool alone.
            total_size = len(adp_data)
        adp_target = int(total_size * adp_weight)
        xlam_target = int(total_size * xlam_weight)
        adp_sample = random.sample(adp_data, min(adp_target, len(adp_data)))
        xlam_sample = random.sample(xlam_data, min(xlam_target, len(xlam_data)))
        combined = adp_sample + xlam_sample
        random.shuffle(combined)
        log_message(f"Created mixture: {len(adp_sample)} ADP + {len(xlam_sample)} xLAM = {len(combined)} total")
        return combined

    def train(self):
        """Run the full pipeline: load model, build data, train, push to Hub.

        Returns a human-readable result string; all failures are caught,
        logged, and reflected in ``training_state["status"]``.
        """
        try:
            training_state["status"] = "preparing"
            log_message("=" * 80)
            log_message("Starting Zen VL Training on HuggingFace Space")
            log_message("=" * 80)

            # Load model and processor.
            training_state["status"] = "loading_model"
            log_message(f"Loading model: {self.model_name}")
            model = Qwen3VLForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            processor = Qwen3VLProcessor.from_pretrained(self.model_name)
            log_message("Model and processor loaded successfully")

            # Load datasets, splitting the sample budget 80/20 ADP/xLAM.
            training_state["status"] = "loading_data"
            max_samples = self.config["max_samples"]
            adp_data = self.load_adp_data(max_samples=int(max_samples * 0.8))
            xlam_data = self.load_xlam_data(max_samples=int(max_samples * 0.2))

            # Create mixture and convert to a HuggingFace Dataset.
            combined_data = self.create_balanced_mixture(adp_data, xlam_data)
            dataset = Dataset.from_list(combined_data)
            log_message(f"Created dataset with {len(dataset)} examples")

            # Training arguments.
            training_state["status"] = "configuring"
            output_dir = f"./output/{self.model_size}"
            training_args = TrainingArguments(
                output_dir=output_dir,
                num_train_epochs=3,
                per_device_train_batch_size=self.config["batch_size"],
                gradient_accumulation_steps=self.config["gradient_accumulation"],
                learning_rate=self.config["learning_rate"],
                warmup_steps=500,
                logging_steps=10,
                save_steps=500,
                save_total_limit=3,
                fp16=False,
                bf16=True,
                push_to_hub=True,
                hub_model_id=self.output_name,
                hub_strategy="every_save",
                report_to="tensorboard",
            )
            log_message("Training configuration:")
            log_message(f"  Epochs: {training_args.num_train_epochs}")
            log_message(f"  Batch size: {training_args.per_device_train_batch_size}")
            log_message(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
            log_message(f"  Learning rate: {training_args.learning_rate}")
            log_message(f"  Total samples: {len(dataset)}")

            # Calculate total optimizer steps for the UI progress display.
            # BUG FIX: num_train_epochs is stored as a float by
            # TrainingArguments, so the old product could be a float; also
            # guard against 0 for tiny datasets.
            updates_per_epoch = len(dataset) // (
                training_args.per_device_train_batch_size
                * training_args.gradient_accumulation_steps
            )
            total_steps = max(1, int(updates_per_epoch * training_args.num_train_epochs))
            training_state["total_steps"] = total_steps
            log_message(f"  Total training steps: {total_steps}")

            # Initialize trainer.
            # NOTE(review): no data_collator/processing_class is passed, so
            # Trainer's default collation must handle the raw multimodal
            # records — confirm this against the dataset schema.
            training_state["status"] = "training"
            log_message("Initializing trainer...")
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )

            # Start training.
            log_message("=" * 80)
            log_message("TRAINING STARTED")
            log_message("=" * 80)
            result = trainer.train()

            # Training completed; upload the final model.
            training_state["status"] = "uploading"
            log_message("=" * 80)
            log_message("TRAINING COMPLETED")
            log_message("=" * 80)
            log_message(f"Final loss: {result.training_loss:.4f}")
            log_message(f"Uploading model to {self.output_name}...")
            trainer.push_to_hub()

            training_state["status"] = "completed"
            training_state["progress"] = 100
            log_message("=" * 80)
            log_message("SUCCESS! Model uploaded to HuggingFace")
            log_message("=" * 80)
            return "Training completed successfully!"
        except Exception as e:
            # Surface the failure in the UI instead of crashing the Space.
            training_state["status"] = "error"
            error_msg = f"Training failed: {str(e)}"
            log_message(error_msg)
            return error_msg
def get_training_status():
    """Return (status line, progress, recent log tail) for the Gradio UI."""
    state = training_state
    step_display = f"{state['current_step']}/{state['total_steps']}"
    labels = {
        "idle": "⏸️ Ready to start training",
        "preparing": "πŸ”§ Preparing training environment...",
        "loading_model": "πŸ“¦ Loading model and processor...",
        "loading_data": "πŸ“š Loading training datasets...",
        "configuring": "βš™οΈ Configuring training parameters...",
        "training": f"πŸš€ Training in progress: {step_display} steps",
        "uploading": "☁️ Uploading model to HuggingFace...",
        "completed": "βœ… Training completed successfully!",
        "error": "❌ Training failed",
    }
    # Unknown states fall back to the raw status string; show at most the
    # last 50 log lines to keep the textbox light.
    status_line = labels.get(state["status"], state["status"])
    log_tail = "\n".join(state["logs"][-50:])
    return status_line, state["progress"], log_tail
def start_training(model_size, gpu_type):
    """Run a full (blocking) training job and return its result message."""
    log_message(f"Starting training for {model_size} on {gpu_type}")
    job = ZenVLTrainer(model_size=model_size, gpu_type=gpu_type)
    return job.train()
# Gradio Interface — built at import time so HuggingFace Spaces can serve
# `demo` directly.  Components are registered in-order inside the Blocks
# context; do not reorder.
with gr.Blocks(title="Zen VL Training") as demo:
    gr.Markdown("""
# 🧘 Zen VL Training Space
Train zen-vl models with combined ADP+xLAM datasets on HuggingFace Pro GPUs.
**Datasets:**
- Agent Data Protocol (ADP): ~220k agent trajectories
- xLAM Function Calling: 60k function calling examples
**Training Time Estimates:**
- 4B model on A10G: ~6-8 hours
- 8B model on A100: ~10-12 hours
- 30B model on A100-80GB: ~20-24 hours
""")
    with gr.Row():
        # Model / hardware selectors; values feed straight into start_training().
        model_size = gr.Dropdown(
            choices=["4b", "8b", "30b"],
            value="4b",
            label="Model Size"
        )
        gpu_type = gr.Dropdown(
            choices=["a10g", "a100-large", "a100"],
            value="a10g",
            label="GPU Type"
        )
    start_btn = gr.Button("πŸš€ Start Training", variant="primary")
    status_text = gr.Textbox(label="Status", value="Ready to start training")
    progress_bar = gr.Slider(minimum=0, maximum=100, value=0, label="Progress")
    logs_text = gr.Textbox(label="Training Logs", lines=20, max_lines=50)
    # Auto-refresh status every 10 seconds
    demo.load(
        get_training_status,
        None,
        [status_text, progress_bar, logs_text],
        every=10
    )
    # NOTE(review): start_training blocks until the whole run finishes, so
    # this click handler occupies a Gradio worker for hours — confirm intended.
    start_btn.click(
        start_training,
        inputs=[model_size, gpu_type],
        outputs=[status_text]
    )
if __name__ == "__main__":
    import threading

    # Log the Space identity when running on HuggingFace infrastructure.
    if os.environ.get("SPACE_ID"):
        log_message(f"Running in HuggingFace Space: {os.environ['SPACE_ID']}")

    # AUTO-START TRAINING
    def auto_start_training():
        """Auto-start training when Space launches"""
        time.sleep(5)  # Wait for Space to fully initialize
        print("πŸš€ AUTO-STARTING TRAINING (4B model on A10G)")
        trainer = ZenVLTrainer(model_size="4b", gpu_type="a10g")
        trainer.train()

    # BUG FIX: demo.launch() blocks the main thread, so this thread must be
    # started BEFORE launching the UI.  Previously it was started after
    # launch() and therefore never ran until server shutdown.
    training_thread = threading.Thread(target=auto_start_training, daemon=True)
    training_thread.start()

    # Launch Gradio interface (blocks until the server stops).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )