import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed in recent versions
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, get_scheduler
from datasets import load_dataset
from PIL import Image
import os
from tqdm import tqdm

# --- 1. Configuration ---
# A simple class to hold our configuration
class Config:
    # Model IDs
    IMAGE_ENCODER_ID = "unum-cloud/uform3-image-text-english-large"
    TEXT_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"
    
    # Dataset
    DATASET_ID = "recastai/LAION-art-EN-improved-captions"
    
    # Training Parameters
    LR = 5e-5
    NUM_TRAIN_STEPS = 500  # Adjust this number. 500 steps is a quick test. 10,000+ would be better.
    BATCH_SIZE = 4      # Lower if you run out of memory
    
    # Projector Dimensions
    IMAGE_EMBED_DIM = 768 # From uform3
    TEXT_EMBED_DIM = 1024 # From Qwen1.5-0.5B
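    # NOTE: these sizes are assumptions tied to the two checkpoints above. If you swap
    # models, read the real values from their configs, e.g.
    # AutoConfig.from_pretrained(TEXT_MODEL_ID).hidden_size for the text side.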
    
    # Paths
    PROJECTOR_WEIGHTS_PATH = "projector_weights.pt"

# --- 2. The Multimodal Model Architecture ---
# This class combines the frozen encoders with our trainable projector
class MultimodalModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load and freeze the vision encoder
        self.vision_encoder = AutoModel.from_pretrained(
            config.IMAGE_ENCODER_ID, trust_remote_code=True
        ).eval()  # eval mode; the encoder stays frozen throughout training
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Load and freeze the language model. AutoModelForCausalLM keeps the LM head,
        # which the captioning loss and .generate() both require.
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.TEXT_MODEL_ID
        ).eval()
        for param in self.language_model.parameters():
            param.requires_grad = False
            
        # Define our trainable projector
        self.projector = nn.Sequential(
            nn.Linear(config.IMAGE_EMBED_DIM, config.IMAGE_EMBED_DIM * 2),
            nn.ReLU(),
            nn.Linear(config.IMAGE_EMBED_DIM * 2, config.TEXT_EMBED_DIM)
        )
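        # Illustrative shape sketch (assuming get_image_features returns one pooled
        # vector per image): the projector maps (batch, IMAGE_EMBED_DIM) ->
        # (batch, TEXT_EMBED_DIM), i.e. into the language model's embedding space.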

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # 1. Get one pooled image embedding per image from the frozen vision encoder
        image_embeds = self.vision_encoder.get_image_features(pixel_values=pixel_values)

        # 2. Project the image embeddings to match the text model's dimension
        projected_image_embeds = self.projector(image_embeds)
        
        # 3. Get text embeddings from the language model
        text_embeds = self.language_model.get_input_embeddings()(input_ids)

        # 4. Concatenate them: [Image Embedding, Text Embedding]
        # The projected image embed acts as a "visual prefix"
        inputs_embeds = torch.cat([projected_image_embeds.unsqueeze(1), text_embeds], dim=1)
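        # Resulting shapes (illustrative): the projected image embedding (B, D_text) is
        # unsqueezed to (B, 1, D_text) and prepended to text_embeds (B, T, D_text),
        # so inputs_embeds has shape (B, T + 1, D_text).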
        
        # 5. Get language model outputs
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        
        return outputs

# --- 3. The Training Function ---
def train_projector(training_steps, learning_rate, batch_size, progress=gr.Progress()):
    if not torch.cuda.is_available():
        yield "Training requires a GPU. Please provision one for this Space."
        return

    device = "cuda"
    config = Config()
    config.NUM_TRAIN_STEPS = int(training_steps)
    config.LR = float(learning_rate)
    config.BATCH_SIZE = int(batch_size)
    
    yield "Initializing models and tokenizers..."
    
    # Load processors
    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # fall back to EOS if no pad token is defined

    # Instantiate the combined model
    model = MultimodalModel(config).to(device)
    
    # Load and preprocess the dataset
    yield "Loading and preprocessing dataset (this may take a moment)..."
    
    def preprocess(batch):
        # Load images item-by-item and skip any that fail, so one bad file
        # doesn't discard the whole batch.
        images, captions = [], []
        for path, caption in zip(batch['image_path'], batch['caption']):
            try:
                images.append(Image.open(path).convert("RGB"))
                captions.append(caption)
            except Exception:
                continue

        if not images:
            return {'pixel_values': [], 'input_ids': [], 'attention_mask': []}

        # Process images
        image_inputs = image_processor(images=images, return_tensors="pt")

        # Process text
        text_inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=64, return_tensors="pt")

        return {
            'pixel_values': image_inputs['pixel_values'],
            'input_ids': text_inputs['input_ids'],
            'attention_mask': text_inputs['attention_mask']
        }

    # Stream the dataset to avoid downloading it in full. remove_columns drops the raw
    # columns so batches that shrank (failed image loads) still line up with the mapped ones.
    dataset = load_dataset(config.DATASET_ID, streaming=True, split="train")
    processed_dataset = dataset.map(
        preprocess,
        batched=True,
        batch_size=config.BATCH_SIZE,
        remove_columns=["image_path", "caption"],
    )

    dataloader = DataLoader(processed_dataset.with_format("torch"), batch_size=config.BATCH_SIZE)

    # Setup optimizer and scheduler
    optimizer = AdamW(model.projector.parameters(), lr=config.LR)
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=config.NUM_TRAIN_STEPS
    )

    # Training Loop
    model.projector.train()
    progress(0, desc="Starting Training")
    
    global_step = 0
    for batch in tqdm(dataloader, desc="Training Steps"):
        if global_step >= config.NUM_TRAIN_STEPS:
            break
        
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        
        # Prepare labels for the language-model loss: mask padded positions and the visual prefix
        labels = input_ids.clone()
        labels[batch['attention_mask'].to(device) == 0] = -100  # no loss on padding
        # The visual prefix position has no target token either
        image_part_label = torch.full((labels.size(0), 1), -100, dtype=torch.long, device=device)
        labels = torch.cat([image_part_label, labels], dim=1)
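        # -100 is the ignore_index used by Hugging Face causal-LM losses, so the masked
        # positions contribute nothing to the cross-entropy.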

        # Prepare attention mask for combined input
        # We need to add a '1' for the image embedding
        attention_mask = torch.cat([torch.ones_like(image_part_label), batch['attention_mask'].to(device)], dim=1)
        
        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        if global_step % 10 == 0:
            yield f"Step: {global_step}/{config.NUM_TRAIN_STEPS}, Loss: {loss.item():.4f}"
            progress(global_step / config.NUM_TRAIN_STEPS)
        
        global_step += 1

    yield "Training finished. Saving projector weights..."
    torch.save(model.projector.state_dict(), config.PROJECTOR_WEIGHTS_PATH)
    yield f"Projector weights saved to {config.PROJECTOR_WEIGHTS_PATH}. You can now use the Inference tab."


# --- 4. The Inference Function ---
def run_inference(image_pil):
    if not os.path.exists(Config.PROJECTOR_WEIGHTS_PATH):
        return "Projector weights not found. Please train the model first using the 'Training' tab."
    if image_pil is None:
        return "Please upload an image."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = Config()
    
    # Load all components for inference
    image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
    model = MultimodalModel(config).to(device).eval()
    
    # Load our trained projector weights
    model.projector.load_state_dict(torch.load(config.PROJECTOR_WEIGHTS_PATH, map_location=device))
    
    # Prepare the image
    image_tensors = image_processor(images=[image_pil], return_tensors="pt")['pixel_values'].to(device)
    
    # Prepare the prompt for the language model
    prompt = "Describe this image in one sentence."
    prompt_tokens = tokenizer(prompt, return_tensors="pt")
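    # NOTE: Qwen1.5-Chat normally expects its chat template (tokenizer.apply_chat_template);
    # a raw prompt string is kept here for simplicity.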

    # Get image and text embeddings
    with torch.no_grad():
        image_embeds = model.vision_encoder.get_image_features(pixel_values=image_tensors)
        projected_embeds = model.projector(image_embeds)
        text_embeds = model.language_model.get_input_embeddings()(prompt_tokens.input_ids.to(device))
        
        # Combine them to form the input for the generate function
        inputs_embeds = torch.cat([projected_embeds.unsqueeze(1), text_embeds], dim=1)

        # Generate text
        output_ids = model.language_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id  # avoid the missing-pad-token warning
        )

    # Decode and return the result
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    # Depending on the transformers version, generation from inputs_embeds may or may not
    # echo the prompt text; strip it if it is present.
    cleaned_text = generated_text.replace(prompt, "").strip()
    return cleaned_text


# --- 5. Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Image Captioning Model Training and Inference")
    gr.Markdown("Connects `uform3` (Vision) and `Qwen` (Language) by training a projector layer.")
    
    with gr.Tab("Training"):
        gr.Markdown("## Step 1: Train the Projector")
        gr.Markdown("This will train a small neural network to translate image features into a format the language model can understand. **This requires a GPU and will take time.**")
        
        steps_input = gr.Number(label="Number of Training Steps", value=Config.NUM_TRAIN_STEPS)
        lr_input = gr.Number(label="Learning Rate", value=Config.LR)
        batch_size_input = gr.Number(label="Batch Size (lower if you get OOM errors)", value=Config.BATCH_SIZE)
        
        start_training_btn = gr.Button("Start Training")
        training_status = gr.Textbox(label="Training Status", lines=10, interactive=False)
        
    with gr.Tab("Inference"):
        gr.Markdown("## Step 2: Describe an Image")
        gr.Markdown("Upload an image to generate a description using your newly trained projector.")
        
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Caption")
        
        inference_btn = gr.Button("Generate Caption")

    # Connect UI components to functions
    start_training_btn.click(
        fn=train_projector,
        inputs=[steps_input, lr_input, batch_size_input],
        outputs=[training_status]
    )
    
    inference_btn.click(
        fn=run_inference,
        inputs=[image_input],
        outputs=[caption_output]
    )

demo.queue().launch()  # queue() enables streaming of the training function's yielded status updates