import os

import gradio as gr
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, get_scheduler
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm


# --- 1. Configuration ---
# A simple class to hold our configuration
class Config:
    # Model IDs
    IMAGE_ENCODER_ID = "unum-cloud/uform3-image-text-english-large"
    TEXT_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"

    # Dataset
    DATASET_ID = "recastai/LAION-art-EN-improved-captions"

    # Training parameters
    LR = 5e-5
    NUM_TRAIN_STEPS = 500  # 500 steps is a quick test; 10,000+ would be better.
    BATCH_SIZE = 4         # Lower this if you run out of memory.

    # Projector dimensions
    IMAGE_EMBED_DIM = 768   # From uform3
    TEXT_EMBED_DIM = 1024   # From Qwen1.5-0.5B

    # Paths
    PROJECTOR_WEIGHTS_PATH = "projector_weights.pt"


# --- 2. The Multimodal Model Architecture ---
# This class combines the frozen encoders with our trainable projector.
class MultimodalModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load and freeze the vision encoder
        self.vision_encoder = AutoModel.from_pretrained(
            config.IMAGE_ENCODER_ID, trust_remote_code=True
        ).eval()  # .eval() is important
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Load and freeze the language model. The causal-LM head is required so the
        # model can compute a loss from `labels` and generate text at inference time.
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.TEXT_MODEL_ID
        ).eval()
        for param in self.language_model.parameters():
            param.requires_grad = False

        # Define our trainable projector
        self.projector = nn.Sequential(
            nn.Linear(config.IMAGE_EMBED_DIM, config.IMAGE_EMBED_DIM * 2),
            nn.ReLU(),
            nn.Linear(config.IMAGE_EMBED_DIM * 2, config.TEXT_EMBED_DIM),
        )

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # 1. Get image embeddings from the vision encoder (a single vector per image)
        image_embeds = self.vision_encoder.get_image_features(pixel_values=pixel_values)

        # 2. Project the image embeddings to match the text model's dimension
        projected_image_embeds = self.projector(image_embeds)

        # 3. Get text embeddings from the language model
        text_embeds = self.language_model.get_input_embeddings()(input_ids)

        # 4. Concatenate them: [image embedding, text embeddings].
        #    The projected image embedding acts as a "visual prefix".
        inputs_embeds = torch.cat([projected_image_embeds.unsqueeze(1), text_embeds], dim=1)

        # 5. Run the language model on the combined sequence
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs


# --- 3. The Training Function ---
def train_projector(training_steps, learning_rate, batch_size, progress=gr.Progress()):
    if not torch.cuda.is_available():
        yield "Training requires a GPU. Please provision one for this Space."
        return

    device = "cuda"
    config = Config()
    config.NUM_TRAIN_STEPS = int(training_steps)
    config.LR = float(learning_rate)
    config.BATCH_SIZE = int(batch_size)

    yield "Initializing models and tokenizers..."

    # Load processors
    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token  # Qwen doesn't define a pad token by default

    # Instantiate the combined model
    model = MultimodalModel(config).to(device)

    # Load and preprocess the dataset
    yield "Loading and preprocessing dataset (this may take a moment)..."
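
    # The nested helper below turns a raw dataset batch (image paths + captions) into
    # pixel values and token IDs; it is applied lazily over the streaming dataset.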
    def preprocess(batch):
        # Handle potential errors if an image fails to load
        try:
            images = [Image.open(f).convert("RGB") for f in batch['image_path']]
        except Exception:
            return {'pixel_values': None}
        captions = batch['caption']

        # Process images
        image_inputs = image_processor(images=images, return_tensors="pt")
        # Process text
        text_inputs = tokenizer(
            captions, padding="max_length", truncation=True, max_length=64, return_tensors="pt"
        )
        return {
            'pixel_values': image_inputs['pixel_values'],
            'input_ids': text_inputs['input_ids'],
            'attention_mask': text_inputs['attention_mask'],
        }

    # Use streaming to avoid downloading the whole dataset
    dataset = load_dataset(config.DATASET_ID, streaming=True, split="train")
    processed_dataset = dataset.map(preprocess, batched=True, batch_size=config.BATCH_SIZE)
    # Filter out failed image loads
    processed_dataset = processed_dataset.filter(lambda example: example['pixel_values'] is not None)
    dataloader = DataLoader(processed_dataset.with_format("torch"), batch_size=config.BATCH_SIZE)

    # Set up the optimizer and scheduler (only the projector parameters are trained)
    optimizer = AdamW(model.projector.parameters(), lr=config.LR)
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=config.NUM_TRAIN_STEPS,
    )

    # Training loop
    model.projector.train()
    progress(0, desc="Starting Training")
    global_step = 0

    for batch in tqdm(dataloader, desc="Training Steps"):
        if global_step >= config.NUM_TRAIN_STEPS:
            break

        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)

        # Prepare labels for the language-model loss calculation
        labels = input_ids.clone()
        # The visual prefix position doesn't have a label (-100 is ignored by the loss)
        image_part_label = torch.full((labels.size(0), 1), -100, dtype=torch.long, device=device)
        labels = torch.cat([image_part_label, labels], dim=1)

        # Prepare the attention mask for the combined input:
        # we need to add a '1' for the image embedding position
        attention_mask = torch.cat(
            [torch.ones_like(image_part_label), batch['attention_mask'].to(device)], dim=1
        )

        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % 10 == 0:
            yield f"Step: {global_step}/{config.NUM_TRAIN_STEPS}, Loss: {loss.item():.4f}"
            progress(global_step / config.NUM_TRAIN_STEPS)

        global_step += 1

    yield "Training finished. Saving projector weights..."
    torch.save(model.projector.state_dict(), config.PROJECTOR_WEIGHTS_PATH)
    yield f"Projector weights saved to {config.PROJECTOR_WEIGHTS_PATH}. You can now use the Inference tab."


# --- 4. The Inference Function ---
def run_inference(image_pil):
    if not os.path.exists(Config.PROJECTOR_WEIGHTS_PATH):
        return "Projector weights not found. Please train the model first using the 'Training' tab."
    if image_pil is None:
        return "Please upload an image."
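
    # Rebuild the frozen encoders, load the trained projector weights, then prepend the
    # projected image embedding to the prompt embeddings before generating a caption.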
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = Config()

    # Load all components for inference
    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    model = MultimodalModel(config).to(device).eval()

    # Load our trained projector weights
    model.projector.load_state_dict(torch.load(config.PROJECTOR_WEIGHTS_PATH, map_location=device))

    # Prepare the image
    image_tensors = image_processor(images=[image_pil], return_tensors="pt")['pixel_values'].to(device)

    # Prepare the prompt for the language model
    prompt = "Describe this image in one sentence."
    prompt_tokens = tokenizer(prompt, return_tensors="pt")

    # Get image and text embeddings
    with torch.no_grad():
        image_embeds = model.vision_encoder.get_image_features(pixel_values=image_tensors)
        projected_embeds = model.projector(image_embeds)
        text_embeds = model.language_model.get_input_embeddings()(prompt_tokens.input_ids.to(device))

        # Combine them to form the input for the generate function
        inputs_embeds = torch.cat([projected_embeds.unsqueeze(1), text_embeds], dim=1)

        # Generate text
        output_ids = model.language_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode and return the result
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The output can include the original prompt, so strip it if present
    cleaned_text = generated_text.replace(prompt, "").strip()
    return cleaned_text


# --- 5. Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Image Captioning Model Training and Inference")
    gr.Markdown("Connects `uform3` (vision) and `Qwen` (language) by training a projector layer.")

    with gr.Tab("Training"):
        gr.Markdown("## Step 1: Train the Projector")
        gr.Markdown(
            "This will train a small neural network to translate image features into a format "
            "the language model can understand. **This requires a GPU and will take time.**"
        )
        steps_input = gr.Number(label="Number of Training Steps", value=Config.NUM_TRAIN_STEPS)
        lr_input = gr.Number(label="Learning Rate", value=Config.LR)
        batch_size_input = gr.Number(label="Batch Size (lower if you get OOM errors)", value=Config.BATCH_SIZE)
        start_training_btn = gr.Button("Start Training")
        training_status = gr.Textbox(label="Training Status", lines=10, interactive=False)

    with gr.Tab("Inference"):
        gr.Markdown("## Step 2: Describe an Image")
        gr.Markdown("Upload an image to generate a description using your newly trained projector.")
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Caption")
        inference_btn = gr.Button("Generate Caption")

    # Connect UI components to functions
    start_training_btn.click(
        fn=train_projector,
        inputs=[steps_input, lr_input, batch_size_input],
        outputs=[training_status],
    )
    inference_btn.click(
        fn=run_inference,
        inputs=[image_input],
        outputs=[caption_output],
    )

demo.launch()
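
# --- Optional: quick shape sanity check for the projector ---
# A minimal sketch, not part of the original app: run it separately (e.g. in a notebook,
# before launching the UI) to confirm the projector maps vision features of size
# IMAGE_EMBED_DIM to the language model's hidden size TEXT_EMBED_DIM. It mirrors the
# projector defined in MultimodalModel without downloading either pretrained model.
#
#     import torch
#     cfg = Config()
#     projector = torch.nn.Sequential(
#         torch.nn.Linear(cfg.IMAGE_EMBED_DIM, cfg.IMAGE_EMBED_DIM * 2),
#         torch.nn.ReLU(),
#         torch.nn.Linear(cfg.IMAGE_EMBED_DIM * 2, cfg.TEXT_EMBED_DIM),
#     )
#     dummy = torch.randn(2, cfg.IMAGE_EMBED_DIM)      # batch of 2 fake image embeddings
#     assert projector(dummy).shape == (2, cfg.TEXT_EMBED_DIM)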