sajofu committed
Commit f129010 · verified · 1 Parent(s): a7b92c4

Create app.py

Files changed (1)
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torch.optim import AdamW
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, get_scheduler
+ from datasets import load_dataset
+ from PIL import Image
+ import os
+ from tqdm import tqdm
+
+ # --- 1. Configuration ---
+ # A simple class to hold our configuration
+ class Config:
+     # Model IDs
+     IMAGE_ENCODER_ID = "unum-cloud/uform3-image-text-english-large"
+     TEXT_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"
+
+     # Dataset
+     DATASET_ID = "recastai/LAION-art-EN-improved-captions"
+
+     # Training Parameters
+     LR = 5e-5
+     NUM_TRAIN_STEPS = 500  # Adjust this number. 500 steps is a quick test. 10,000+ would be better.
+     BATCH_SIZE = 4  # Lower if you run out of memory
+
+     # Projector Dimensions
+     IMAGE_EMBED_DIM = 768   # From uform3
+     TEXT_EMBED_DIM = 1024   # From Qwen1.5-0.5B
+
+     # Paths
+     PROJECTOR_WEIGHTS_PATH = "projector_weights.pt"
+
+ # --- 2. The Multimodal Model Architecture ---
+ # This class combines the frozen encoders with our trainable projector
+ class MultimodalModel(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         # Load and freeze the vision encoder
+         self.vision_encoder = AutoModel.from_pretrained(
+             config.IMAGE_ENCODER_ID, trust_remote_code=True
+         ).eval()  # .eval() is important
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+         # Load and freeze the language model
+         # (The causal-LM head is needed so `labels` produce a loss and .generate() works)
+         self.language_model = AutoModelForCausalLM.from_pretrained(
+             config.TEXT_MODEL_ID
+         ).eval()
+         for param in self.language_model.parameters():
+             param.requires_grad = False
+
+         # Define our trainable projector
+         self.projector = nn.Sequential(
+             nn.Linear(config.IMAGE_EMBED_DIM, config.IMAGE_EMBED_DIM * 2),
+             nn.ReLU(),
+             nn.Linear(config.IMAGE_EMBED_DIM * 2, config.TEXT_EMBED_DIM)
+         )
+
+     def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
+         # 1. Get image embeddings from the vision encoder
+         # We need to process this to get a single vector per image
+         image_outputs = self.vision_encoder.get_image_features(pixel_values=pixel_values)
+         image_embeds = image_outputs
+
+         # 2. Project the image embeddings to match the text model's dimension
+         projected_image_embeds = self.projector(image_embeds)
+
+         # 3. Get text embeddings from the language model
+         text_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         # 4. Concatenate them: [Image Embedding, Text Embedding]
+         # The projected image embed acts as a "visual prefix"
+         inputs_embeds = torch.cat([projected_image_embeds.unsqueeze(1), text_embeds], dim=1)
+
+         # 5. Get language model outputs
+         outputs = self.language_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             labels=labels
+         )
+
+         return outputs
+
+ # --- 3. The Training Function ---
+ def train_projector(training_steps, learning_rate, batch_size, progress=gr.Progress()):
+     if not torch.cuda.is_available():
+         yield "Training requires a GPU. Please provision one for this Space."
+         return
+
+     device = "cuda"
+     config = Config()
+     config.NUM_TRAIN_STEPS = int(training_steps)
+     config.LR = float(learning_rate)
+     config.BATCH_SIZE = int(batch_size)
+
+     yield "Initializing models and tokenizers..."
+
+     # Load processors
+     image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
+     tokenizer.pad_token = tokenizer.eos_token  # Qwen doesn't have a pad token by default
+
+     # Instantiate the combined model
+     model = MultimodalModel(config).to(device)
+
+     # Load and preprocess the dataset
+     yield "Loading and preprocessing dataset (this may take a moment)..."
+
+     def preprocess(batch):
+         # We need to handle potential errors if an image fails to load
+         try:
+             images = [Image.open(f).convert("RGB") for f in batch['image_path']]
+         except Exception:
+             return {'pixel_values': None}
+
+         captions = batch['caption']
+
+         # Process images
+         image_inputs = image_processor(images=images, return_tensors="pt")
+
+         # Process text
+         text_inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=64, return_tensors="pt")
+
+         return {
+             'pixel_values': image_inputs['pixel_values'],
+             'input_ids': text_inputs['input_ids'],
+             'attention_mask': text_inputs['attention_mask']
+         }
+
+     # Use streaming to avoid downloading the whole dataset
+     dataset = load_dataset(config.DATASET_ID, streaming=True, split="train")
+     processed_dataset = dataset.map(preprocess, batched=True, batch_size=config.BATCH_SIZE)
+
+     # Filter out failed image loads
+     processed_dataset = processed_dataset.filter(lambda example: example['pixel_values'] is not None)
+
+     dataloader = DataLoader(processed_dataset.with_format("torch"), batch_size=config.BATCH_SIZE)
+
+     # Setup optimizer and scheduler
+     optimizer = AdamW(model.projector.parameters(), lr=config.LR)
+     scheduler = get_scheduler(
+         "linear",
+         optimizer=optimizer,
+         num_warmup_steps=0,
+         num_training_steps=config.NUM_TRAIN_STEPS
+     )
+
+     # Training Loop
+     model.projector.train()
+     progress(0, desc="Starting Training")
+
+     global_step = 0
+     for batch in tqdm(dataloader, desc="Training Steps"):
+         if global_step >= config.NUM_TRAIN_STEPS:
+             break
+
+         pixel_values = batch['pixel_values'].to(device)
+         input_ids = batch['input_ids'].to(device)
+
+         # Prepare labels for language model loss calculation
+         labels = input_ids.clone()
+         # The visual part doesn't have a label
+         image_part_label = torch.full((labels.size(0), 1), -100, dtype=torch.long, device=device)
+         labels = torch.cat([image_part_label, labels], dim=1)
+
+         # Prepare attention mask for combined input
+         # We need to add a '1' for the image embedding
+         attention_mask = torch.cat([torch.ones_like(image_part_label), batch['attention_mask'].to(device)], dim=1)
+
+         # Forward pass
+         outputs = model(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             labels=labels
+         )
+
+         loss = outputs.loss
+
+         # Backward pass
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+         optimizer.zero_grad()
+
+         if global_step % 10 == 0:
+             yield f"Step: {global_step}/{config.NUM_TRAIN_STEPS}, Loss: {loss.item():.4f}"
+             progress(global_step / config.NUM_TRAIN_STEPS)
+
+         global_step += 1
+
+     yield "Training finished. Saving projector weights..."
+     torch.save(model.projector.state_dict(), config.PROJECTOR_WEIGHTS_PATH)
+     yield f"Projector weights saved to {config.PROJECTOR_WEIGHTS_PATH}. You can now use the Inference tab."
+
+
+ # --- 4. The Inference Function ---
+ def run_inference(image_pil):
+     if not os.path.exists(Config.PROJECTOR_WEIGHTS_PATH):
+         return "Projector weights not found. Please train the model first using the 'Training' tab."
+     if image_pil is None:
+         return "Please upload an image."
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     config = Config()
+
+     # Load all components for inference
+     image_processor = AutoProcessor.from_pretrained(config.IMAGE_ENCODER_ID, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_ID)
+     model = MultimodalModel(config).to(device).eval()
+
+     # Load our trained projector weights
+     model.projector.load_state_dict(torch.load(config.PROJECTOR_WEIGHTS_PATH, map_location=device))
+
+     # Prepare the image
+     image_tensors = image_processor(images=[image_pil], return_tensors="pt")['pixel_values'].to(device)
+
+     # Prepare the prompt for the language model
+     prompt = "Describe this image in one sentence."
+     prompt_tokens = tokenizer(prompt, return_tensors="pt")
+
+     # Get image and text embeddings
+     with torch.no_grad():
+         image_embeds = model.vision_encoder.get_image_features(pixel_values=image_tensors)
+         projected_embeds = model.projector(image_embeds)
+         text_embeds = model.language_model.get_input_embeddings()(prompt_tokens.input_ids.to(device))
+
+         # Combine them to form the input for the generate function
+         inputs_embeds = torch.cat([projected_embeds.unsqueeze(1), text_embeds], dim=1)
+
+         # Generate text
+         output_ids = model.language_model.generate(
+             inputs_embeds=inputs_embeds,
+             max_new_tokens=50,
+             do_sample=False
+         )
+
+     # Decode and return the result
+     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     # The output often includes the original prompt, so we can clean it up
+     cleaned_text = generated_text.replace(prompt, "").strip()
+     return cleaned_text
+
+
+ # --- 5. Gradio UI ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Captioning Model Training and Inference")
+     gr.Markdown("Connects `uform3` (Vision) and `Qwen` (Language) by training a projector layer.")
+
+     with gr.Tab("Training"):
+         gr.Markdown("## Step 1: Train the Projector")
+         gr.Markdown("This will train a small neural network to translate image features into a format the language model can understand. **This requires a GPU and will take time.**")
+
+         steps_input = gr.Number(label="Number of Training Steps", value=Config.NUM_TRAIN_STEPS)
+         lr_input = gr.Number(label="Learning Rate", value=Config.LR)
+         batch_size_input = gr.Number(label="Batch Size (lower if you get OOM errors)", value=Config.BATCH_SIZE)
+
+         start_training_btn = gr.Button("Start Training")
+         training_status = gr.Textbox(label="Training Status", lines=10, interactive=False)
+
+     with gr.Tab("Inference"):
+         gr.Markdown("## Step 2: Describe an Image")
+         gr.Markdown("Upload an image to generate a description using your newly trained projector.")
+
+         with gr.Row():
+             image_input = gr.Image(type="pil", label="Upload Image")
+             caption_output = gr.Textbox(label="Generated Caption")
+
+         inference_btn = gr.Button("Generate Caption")
+
+     # Connect UI components to functions
+     start_training_btn.click(
+         fn=train_projector,
+         inputs=[steps_input, lr_input, batch_size_input],
+         outputs=[training_status]
+     )
+
+     inference_btn.click(
+         fn=run_inference,
+         inputs=[image_input],
+         outputs=[caption_output]
+     )
+
+ demo.launch()
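
For reference, the saved projector can be reused outside this app by rebuilding the same nn.Sequential defined in MultimodalModel and loading projector_weights.pt. Below is a minimal sketch (not part of this commit); it assumes the dimensions from Config above and uses a random tensor as a stand-in for a real uform3 image embedding.

import torch
import torch.nn as nn

IMAGE_EMBED_DIM = 768    # matches Config.IMAGE_EMBED_DIM (uform3)
TEXT_EMBED_DIM = 1024    # matches Config.TEXT_EMBED_DIM (Qwen1.5-0.5B)

# Same architecture as MultimodalModel.projector, so the state-dict keys line up
projector = nn.Sequential(
    nn.Linear(IMAGE_EMBED_DIM, IMAGE_EMBED_DIM * 2),
    nn.ReLU(),
    nn.Linear(IMAGE_EMBED_DIM * 2, TEXT_EMBED_DIM),
)
projector.load_state_dict(torch.load("projector_weights.pt", map_location="cpu"))
projector.eval()

# Stand-in for a real image embedding (batch of 1)
dummy_image_embed = torch.randn(1, IMAGE_EMBED_DIM)
with torch.no_grad():
    visual_prefix = projector(dummy_image_embed)
print(visual_prefix.shape)  # torch.Size([1, 1024])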