import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, get_scheduler
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch optimizer
from datasets import load_dataset
from PIL import Image
import os
from tqdm import tqdm
# --- 1. Configuration ---
# A simple class to hold our configuration
class Config:
# Model IDs
IMAGE_ENCODER_ID = "unum-cloud/uform3-image-text-english-large"
TEXT_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"
# Dataset
DATASET_ID = "recastai/LAION-art-EN-improved-captions"
# Training Parameters
LR = 5e-5
    NUM_TRAIN_STEPS = 500  # 500 steps is a quick test; 10,000+ would give better results
BATCH_SIZE = 4 # Lower if you run out of memory
# Projector Dimensions
IMAGE_EMBED_DIM = 768 # From uform3
TEXT_EMBED_DIM = 1024 # From Qwen1.5-0.5B
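    # NOTE: these embedding sizes are assumed for the two checkpoints above; if either
    # model is swapped, read the actual sizes from the loaded configs instead
    # (e.g. language_model.config.hidden_size).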
# Paths
PROJECTOR_WEIGHTS_PATH = "projector_weights.pt"
# --- 2. The Multimodal Model Architecture ---
# This class combines the frozen encoders with our trainable projector
class MultimodalModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Load and freeze the vision encoder
self.vision_encoder = AutoModel.from_pretrained(
config.IMAGE_ENCODER_ID, trust_remote_code=True
        ).eval()  # keep the frozen encoder in eval mode (disables dropout, etc.)
for param in self.vision_encoder.parameters():
param.requires_grad = False
        # Load and freeze the language model; the causal-LM head is needed for the
        # loss computation during training and for generate() at inference time
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.TEXT_MODEL_ID
        ).eval()
for param in self.language_model.parameters():
param.requires_grad = False
# Define our trainable projector
self.projector = nn.Sequential(
nn.Linear(config.IMAGE_EMBED_DIM, config.IMAGE_EMBED_DIM * 2),
nn.ReLU(),
nn.Linear(config.IMAGE_EMBED_DIM * 2, config.TEXT_EMBED_DIM)
)
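        # The projector is the only trainable piece: it maps the pooled image vector
        # (batch, IMAGE_EMBED_DIM) through a 2x-wide hidden layer to (batch, TEXT_EMBED_DIM),
        # i.e. a single "soft token" in the language model's embedding space.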
def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # 1. Get a single pooled image embedding per image from the vision encoder
        image_embeds = self.vision_encoder.get_image_features(pixel_values=pixel_values)
# 2. Project the image embeddings to match the text model's dimension
projected_image_embeds = self.projector(image_embeds)
# 3. Get text embeddings from the language model
text_embeds = self.language_model.get_input_embeddings()(input_ids)
# 4. Concatenate them: [Image Embedding, Text Embedding]
# The projected image embed acts as a "visual prefix"
inputs_embeds = torch.cat([projected_image_embeds.unsqueeze(1), text_embeds], dim=1)
# 5. Get language model outputs
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels
)
return outputs
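# Note: forward() prepends one visual embedding, so the effective sequence length is the
# text length + 1; the training loop below extends labels and attention_mask by one
# position to match.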
# --- 3. The Training Function ---
def train_projector(training_steps, learning_rate, batch_size, progress=gr.Progress()):
if not torch.cuda.is_available():
yield "Training requires a GPU. Please provision one for this Space."
return
device = "cuda"
config = Config()
config.NUM_TRAIN_STEPS = int(training_steps)
config.LR = float(learning_rate)
config.BATCH_SIZE = int(batch_size)
yield "Initializing models and tokenizers..."
# Load processors
image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token # Qwen doesn't have a pad token by default
# Instantiate the combined model
model = MultimodalModel(config).to(device)
# Load and preprocess the dataset
yield "Loading and preprocessing dataset (this may take a moment)..."
    def preprocess(batch):
        # Handle image-loading failures: in a batched map every returned column must be
        # a list of the same length, so mark the whole failed batch as None rows and let
        # the filter below drop them.
        try:
            images = [Image.open(f).convert("RGB") for f in batch['image_path']]
        except Exception:
            n = len(batch['image_path'])
            return {'pixel_values': [None] * n, 'input_ids': [None] * n, 'attention_mask': [None] * n}
captions = batch['caption']
# Process images
image_inputs = image_processor(images=images, return_tensors="pt")
# Process text
text_inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=64, return_tensors="pt")
return {
'pixel_values': image_inputs['pixel_values'],
'input_ids': text_inputs['input_ids'],
'attention_mask': text_inputs['attention_mask']
}
# Use streaming to avoid downloading the whole dataset
dataset = load_dataset(config.DATASET_ID, streaming=True, split="train")
processed_dataset = dataset.map(preprocess, batched=True, batch_size=config.BATCH_SIZE)
# Filter out failed image loads
processed_dataset = processed_dataset.filter(lambda example: example['pixel_values'] is not None)
dataloader = DataLoader(processed_dataset.with_format("torch"), batch_size=config.BATCH_SIZE)
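    # The streaming dataset has no defined length, so the loop below simply counts steps
    # and stops once NUM_TRAIN_STEPS is reached.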
# Setup optimizer and scheduler
optimizer = AdamW(model.projector.parameters(), lr=config.LR)
scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=config.NUM_TRAIN_STEPS
)
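    # Only the projector's parameters are handed to the optimizer, so the frozen vision
    # encoder and language model are never updated.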
# Training Loop
model.projector.train()
progress(0, desc="Starting Training")
global_step = 0
for batch in tqdm(dataloader, desc="Training Steps"):
if global_step >= config.NUM_TRAIN_STEPS:
break
pixel_values = batch['pixel_values'].to(device)
input_ids = batch['input_ids'].to(device)
        # Prepare labels for the language-modelling loss (-100 positions are ignored)
        labels = input_ids.clone()
        labels[batch['attention_mask'].to(device) == 0] = -100  # no loss on padding tokens
        # The visual prefix has no label either
        image_part_label = torch.full((labels.size(0), 1), -100, dtype=torch.long, device=device)
        labels = torch.cat([image_part_label, labels], dim=1)
# Prepare attention mask for combined input
# We need to add a '1' for the image embedding
attention_mask = torch.cat([torch.ones_like(image_part_label), batch['attention_mask'].to(device)], dim=1)
# Forward pass
outputs = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
# Backward pass
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if global_step % 10 == 0:
yield f"Step: {global_step}/{config.NUM_TRAIN_STEPS}, Loss: {loss.item():.4f}"
progress(global_step / config.NUM_TRAIN_STEPS)
global_step += 1
yield "Training finished. Saving projector weights..."
torch.save(model.projector.state_dict(), config.PROJECTOR_WEIGHTS_PATH)
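    # Only the small projector state dict is saved; the frozen vision encoder and
    # language model are reloaded via from_pretrained() in run_inference().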
yield f"Projector weights saved to {config.PROJECTOR_WEIGHTS_PATH}. You can now use the Inference tab."
# --- 4. The Inference Function ---
def run_inference(image_pil):
if not os.path.exists(Config.PROJECTOR_WEIGHTS_PATH):
return "Projector weights not found. Please train the model first using the 'Training' tab."
if image_pil is None:
return "Please upload an image."
device = "cuda" if torch.cuda.is_available() else "cpu"
config = Config()
# Load all components for inference
image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
model = MultimodalModel(config).to(device).eval()
# Load our trained projector weights
model.projector.load_state_dict(torch.load(config.PROJECTOR_WEIGHTS_PATH, map_location=device))
# Prepare the image
image_tensors = image_processor(images=[image_pil], return_tensors="pt")['pixel_values'].to(device)
    # Prepare the text prompt (kept as a plain string for simplicity; the chat-tuned
    # Qwen model may respond better when formatted with tokenizer.apply_chat_template)
prompt = "Describe this image in one sentence."
prompt_tokens = tokenizer(prompt, return_tensors="pt")
# Get image and text embeddings
with torch.no_grad():
image_embeds = model.vision_encoder.get_image_features(pixel_values=image_tensors)
projected_embeds = model.projector(image_embeds)
text_embeds = model.language_model.get_input_embeddings()(prompt_tokens.input_ids.to(device))
# Combine them to form the input for the generate function
inputs_embeds = torch.cat([projected_embeds.unsqueeze(1), text_embeds], dim=1)
# Generate text
output_ids = model.language_model.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=50,
do_sample=False
)
# Decode and return the result
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # When only inputs_embeds is passed, generate() usually returns just the new tokens,
    # but strip the prompt defensively in case it is echoed back
cleaned_text = generated_text.replace(prompt, "").strip()
return cleaned_text
# --- 5. Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown("# Image Captioning Model Training and Inference")
gr.Markdown("Connects `uform3` (Vision) and `Qwen` (Language) by training a projector layer.")
with gr.Tab("Training"):
gr.Markdown("## Step 1: Train the Projector")
gr.Markdown("This will train a small neural network to translate image features into a format the language model can understand. **This requires a GPU and will take time.**")
steps_input = gr.Number(label="Number of Training Steps", value=Config.NUM_TRAIN_STEPS)
lr_input = gr.Number(label="Learning Rate", value=Config.LR)
batch_size_input = gr.Number(label="Batch Size (lower if you get OOM errors)", value=Config.BATCH_SIZE)
start_training_btn = gr.Button("Start Training")
training_status = gr.Textbox(label="Training Status", lines=10, interactive=False)
with gr.Tab("Inference"):
gr.Markdown("## Step 2: Describe an Image")
gr.Markdown("Upload an image to generate a description using your newly trained projector.")
with gr.Row():
image_input = gr.Image(type="pil", label="Upload Image")
caption_output = gr.Textbox(label="Generated Caption")
inference_btn = gr.Button("Generate Caption")
# Connect UI components to functions
start_training_btn.click(
fn=train_projector,
inputs=[steps_input, lr_input, batch_size_input],
outputs=[training_status]
)
inference_btn.click(
fn=run_inference,
inputs=[image_input],
outputs=[caption_output]
)
demo.launch()