import os

import gradio as gr
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, get_scheduler
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm


class Config:
    IMAGE_ENCODER_ID = "unum-cloud/uform3-image-text-english-large"
    TEXT_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"

    DATASET_ID = "recastai/LAION-art-EN-improved-captions"

    LR = 5e-5
    NUM_TRAIN_STEPS = 500
    BATCH_SIZE = 4

    # Embedding widths the projector maps between: vision output -> LM hidden size.
    IMAGE_EMBED_DIM = 768
    TEXT_EMBED_DIM = 1024

    PROJECTOR_WEIGHTS_PATH = "projector_weights.pt"
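

# The model below follows the LLaVA-style recipe: keep both pretrained
# backbones frozen and train only a small MLP "projector" that maps the
# vision encoder's image embedding into the language model's input
# embedding space. The projected image then behaves like a single extra
# token prepended to the caption.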
class MultimodalModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Frozen vision encoder.
        self.vision_encoder = AutoModel.from_pretrained(
            config.IMAGE_ENCODER_ID, trust_remote_code=True
        ).eval()
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Frozen language model. AutoModelForCausalLM (not AutoModel) is required
        # so the model carries an LM head: forward(labels=...) can compute the
        # loss and generate() can produce tokens.
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.TEXT_MODEL_ID
        ).eval()
        for param in self.language_model.parameters():
            param.requires_grad = False

        # The only trainable part: a two-layer MLP from image space to LM space.
        self.projector = nn.Sequential(
            nn.Linear(config.IMAGE_EMBED_DIM, config.IMAGE_EMBED_DIM * 2),
            nn.ReLU(),
            nn.Linear(config.IMAGE_EMBED_DIM * 2, config.TEXT_EMBED_DIM)
        )

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # Pooled image features from the frozen encoder.
        image_embeds = self.vision_encoder.get_image_features(pixel_values=pixel_values)

        # Map them into the language model's embedding space.
        projected_image_embeds = self.projector(image_embeds)

        # Embed the caption tokens with the LM's own embedding table.
        text_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Prepend the projected image as one extra sequence position.
        inputs_embeds = torch.cat([projected_image_embeds.unsqueeze(1), text_embeds], dim=1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs
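

# Shape walk-through for one batch, assuming uform3's get_image_features
# returns a pooled (B, IMAGE_EMBED_DIM) tensor (which this script relies on):
#   pixel_values (B, 3, H, W) --vision_encoder--> (B, 768)
#   --projector--> (B, 1024)  --unsqueeze(1)-->   (B, 1, 1024)
#   input_ids    (B, L)       --embed-->          (B, L, 1024)
#   concat along dim=1        -->                 (B, 1 + L, 1024)
# Labels and attention masks in the training loop below are padded by one
# position to match the extra image slot.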
def train_projector(training_steps, learning_rate, batch_size, progress=gr.Progress()):
    if not torch.cuda.is_available():
        yield "Training requires a GPU. Please provision one for this Space."
        return

    device = "cuda"
    config = Config()
    config.NUM_TRAIN_STEPS = int(training_steps)
    config.LR = float(learning_rate)
    config.BATCH_SIZE = int(batch_size)

    yield "Initializing models and tokenizers..."

    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    # Ensure a pad token exists for fixed-length batching.
    tokenizer.pad_token = tokenizer.eos_token

    model = MultimodalModel(config).to(device)

    yield "Loading and preprocessing dataset (this may take a moment)..."

    def preprocess(batch):
        # Decode defensively: LAION-style datasets contain broken images.
        try:
            images = [Image.open(f).convert("RGB") for f in batch['image_path']]
        except Exception:
            # Sentinel values for the whole batch, one per example so the
            # column lengths stay aligned; the filter below drops them.
            n = len(batch['image_path'])
            return {
                'pixel_values': [None] * n,
                'input_ids': [None] * n,
                'attention_mask': [None] * n
            }

        captions = batch['caption']

        image_inputs = image_processor(images=images, return_tensors="pt")
        text_inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=64, return_tensors="pt")

        return {
            'pixel_values': image_inputs['pixel_values'],
            'input_ids': text_inputs['input_ids'],
            'attention_mask': text_inputs['attention_mask']
        }
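
    # streaming=True yields samples lazily over the network, so the multi-GB
    # dataset never has to be downloaded in full before training starts.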
    dataset = load_dataset(config.DATASET_ID, streaming=True, split="train")
    processed_dataset = dataset.map(preprocess, batched=True, batch_size=config.BATCH_SIZE)

    # Drop examples whose images failed to decode.
    processed_dataset = processed_dataset.filter(lambda example: example['pixel_values'] is not None)

    dataloader = DataLoader(processed_dataset.with_format("torch"), batch_size=config.BATCH_SIZE)
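
    # Only the projector's parameters are handed to the optimizer; the frozen
    # backbones have requires_grad=False, so this is both correct and cheap.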
    optimizer = AdamW(model.projector.parameters(), lr=config.LR)
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=config.NUM_TRAIN_STEPS
    )

    model.projector.train()
    progress(0, desc="Starting Training")

    global_step = 0
    for batch in tqdm(dataloader, desc="Training Steps"):
        if global_step >= config.NUM_TRAIN_STEPS:
            break

        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
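
        # The caption is its own training target (the LM shifts labels
        # internally). The prepended image slot gets label -100, the
        # ignore_index of CrossEntropyLoss, so it contributes no loss.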
        labels = input_ids.clone()
        image_part_label = torch.full((labels.size(0), 1), -100, dtype=torch.long, device=device)
        labels = torch.cat([image_part_label, labels], dim=1)

        # The attention mask also grows by one position for the image slot.
        attention_mask = torch.cat([torch.ones_like(image_part_label), batch['attention_mask'].to(device)], dim=1)

        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % 10 == 0:
            yield f"Step: {global_step}/{config.NUM_TRAIN_STEPS}, Loss: {loss.item():.4f}"
            progress(global_step / config.NUM_TRAIN_STEPS)

        global_step += 1
yield "Training finished. Saving projector weights..." |
|
|
torch.save(model.projector.state_dict(), config.PROJECTOR_WEIGHTS_PATH) |
|
|
yield f"Projector weights saved to {config.PROJECTOR_WEIGHTS_PATH}. You can now use the Inference tab." |
|
|
|
|
|
|
|
|
|
|
|


def run_inference(image_pil):
    if not os.path.exists(Config.PROJECTOR_WEIGHTS_PATH):
        return "Projector weights not found. Please train the model first using the 'Training' tab."
    if image_pil is None:
        return "Please upload an image."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = Config()

    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    model = MultimodalModel(config).to(device).eval()

    # Only the projector has fine-tuned weights to restore; the frozen
    # backbones come straight from their checkpoints.
    model.projector.load_state_dict(torch.load(config.PROJECTOR_WEIGHTS_PATH, map_location=device))
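
    # NOTE: reloading the processor, tokenizer, and both backbones on every
    # click keeps this handler stateless but slow; caching them at module
    # level would be the natural optimization for a real deployment.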

    image_tensors = image_processor(images=[image_pil], return_tensors="pt")['pixel_values'].to(device)

    prompt = "Describe this image in one sentence."
    prompt_tokens = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        # Same layout as training: [projected image][prompt tokens].
        image_embeds = model.vision_encoder.get_image_features(pixel_values=image_tensors)
        projected_embeds = model.projector(image_embeds)
        text_embeds = model.language_model.get_input_embeddings()(prompt_tokens.input_ids.to(device))
        inputs_embeds = torch.cat([projected_embeds.unsqueeze(1), text_embeds], dim=1)
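
        # When generation starts from inputs_embeds rather than input_ids,
        # recent transformers versions return only the newly generated tokens,
        # so the decoded text should not echo the prompt; the replace() after
        # decoding is kept as a defensive fallback.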
        output_ids = model.language_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=50,
            do_sample=False
        )

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    cleaned_text = generated_text.replace(prompt, "").strip()
    return cleaned_text


with gr.Blocks() as demo:
    gr.Markdown("# Image Captioning Model Training and Inference")
    gr.Markdown("Connects `uform3` (Vision) and `Qwen` (Language) by training a projector layer.")

    with gr.Tab("Training"):
        gr.Markdown("## Step 1: Train the Projector")
        gr.Markdown("This will train a small neural network to translate image features into a format the language model can understand. **This requires a GPU and will take time.**")

        steps_input = gr.Number(label="Number of Training Steps", value=Config.NUM_TRAIN_STEPS)
        lr_input = gr.Number(label="Learning Rate", value=Config.LR)
        batch_size_input = gr.Number(label="Batch Size (lower if you get OOM errors)", value=Config.BATCH_SIZE)

        start_training_btn = gr.Button("Start Training")
        training_status = gr.Textbox(label="Training Status", lines=10, interactive=False)

    with gr.Tab("Inference"):
        gr.Markdown("## Step 2: Describe an Image")
        gr.Markdown("Upload an image to generate a description using your newly trained projector.")

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Caption")

        inference_btn = gr.Button("Generate Caption")

    start_training_btn.click(
        fn=train_projector,
        inputs=[steps_input, lr_input, batch_size_input],
        outputs=[training_status]
    )

    inference_btn.click(
        fn=run_inference,
        inputs=[image_input],
        outputs=[caption_output]
    )

demo.launch()