LPX55 committed on
Commit
dd8a1e0
·
verified ·
1 Parent(s): 09d4d06

review: local prompt inf. (#1)

Browse files

- review: local prompt inf. (f8c747170c82f9b807d1c4f08dc22a8c1b3ebaa2)

Files changed (1) hide show
  1. app_local.py +168 -123
app_local.py CHANGED
@@ -9,6 +9,7 @@ from diffusers.utils import is_xformers_available
9
  import os
10
  import re
11
  import gc
 
12
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
13
 
14
  #############################
@@ -78,24 +79,31 @@ Please provide the rewritten instruction in a clean `json` format as:
78
 
79
  def extract_json_response(model_output: str) -> str:
80
  """Extract rewritten instruction from potentially messy JSON output"""
 
 
 
81
  try:
82
  # Try to find the JSON portion in the output
83
  start_idx = model_output.find('{')
84
- end_idx = model_output.rfind('}') + 1
85
- if start_idx == -1 or end_idx == 0:
86
  return None
87
 
 
 
 
88
  json_str = model_output[start_idx:end_idx]
89
- # Clean up common formatting issues
90
- json_str = re.sub(r'(?<!")\b(\w+)\b(?=":)', r'"\1"', json_str) # Add quotes to keys
91
- json_str = re.sub(r':\s*([^"{\[]|true|false|null)', r': "\1"', json_str) # Add quotes to values
 
92
 
93
  # Parse JSON
94
  data = json.loads(json_str)
95
 
96
  # Extract rewritten prompt from possible key variations
97
  possible_keys = [
98
- "Rewritten", "rewritten", "Rewrited", "rewrited",
99
  "Output", "output", "Enhanced", "enhanced"
100
  ]
101
  for key in possible_keys:
@@ -105,16 +113,23 @@ def extract_json_response(model_output: str) -> str:
105
  # Try nested path
106
  if "Response" in data and "Rewritten" in data["Response"]:
107
  return data["Response"]["Rewritten"].strip()
 
 
 
 
 
 
108
 
109
- # Fallback to direct extraction
110
- for value in data.values():
111
- if isinstance(value, str) and 10 < len(value) < 500:
112
- return value.strip()
113
 
114
- except Exception:
115
- pass
116
- return None
117
 
 
 
118
  def polish_prompt(original_prompt: str) -> str:
119
  """Enhanced prompt rewriting using original system prompt with JSON handling"""
120
  load_rewriter()
@@ -136,11 +151,11 @@ def polish_prompt(original_prompt: str) -> str:
136
  with torch.no_grad():
137
  generated_ids = rewriter_model.generate(
138
  **model_inputs,
139
- max_new_tokens=256, # Maintain token count for good JSON generation
140
  do_sample=True,
141
- temperature=0.6,
142
  top_p=0.9,
143
- no_repeat_ngram_size=2,
144
  pad_token_id=rewriter_tokenizer.eos_token_id
145
  )
146
 
@@ -150,36 +165,50 @@ def polish_prompt(original_prompt: str) -> str:
150
  skip_special_tokens=True
151
  ).strip()
152
 
 
 
 
 
 
 
 
153
  # Try to extract JSON content
154
- rewritten_prompt = extract_json_response(enhanced)
155
 
156
  if rewritten_prompt:
157
- # Clean up substitutions from the JSON output
158
- rewritten_prompt = re.sub(r'(Replace|Change|Add) "([^"]*)"', r'\1 \2', rewritten_prompt)
159
- rewritten_prompt = rewritten_prompt.replace('\\"', '"')
160
  return rewritten_prompt
161
 
162
  # Fallback cleanup if JSON extraction fails
163
- print(f"⚠️ JSON extraction failed, using raw output: {enhanced}")
164
- fallback = re.sub(r'```.*?```', '', enhanced, flags=re.DOTALL) # Remove code blocks
165
- fallback = re.sub(r'[\{\}\[\]"]', '', fallback) # Remove JSON artifacts
166
- fallback = fallback.split('\n')[0] # Take first line
 
 
 
 
 
167
 
168
- # Try to extract before colon separator
169
- if ': ' in fallback:
170
- return fallback.split(': ')[1].strip()
 
 
171
 
172
- return fallback.strip()
173
 
174
  # Load main image editing pipeline
175
  pipe = QwenImageEditPipeline.from_pretrained(
176
- "Qwen/Qwen-Image-Edit",
177
  torch_dtype=dtype
178
  ).to(device)
179
 
180
  # Load LoRA weights for acceleration
181
  pipe.load_lora_weights(
182
- "lightx2v/Qwen-Image-Lightning",
183
  weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
184
  )
185
  pipe.fuse_lora()
@@ -189,7 +218,6 @@ if is_xformers_available():
189
  else:
190
  print("xformers not available")
191
 
192
-
193
  def unload_rewriter():
194
  """Clear enhancement model from memory"""
195
  global rewriter_tokenizer, rewriter_model
@@ -206,7 +234,7 @@ def infer(
206
  prompt,
207
  seed=42,
208
  randomize_seed=False,
209
- true_guidance_scale=4.0,
210
  num_inference_steps=8,
211
  rewrite_prompt=False,
212
  num_images_per_prompt=1,
@@ -220,19 +248,19 @@ def infer(
220
  try:
221
  enhanced_instruction = polish_prompt(original_prompt)
222
  prompt_info = (
223
- f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
224
  f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
225
  f"<p><strong>Original:</strong> {original_prompt}</p>"
226
- f"<p><strong>Enhanced:</strong> {enhanced_instruction}</p>"
227
  f"</div>"
228
  )
229
  prompt = enhanced_instruction
230
  except Exception as e:
231
  gr.Warning(f"Prompt enhancement failed: {str(e)}")
232
  prompt_info = (
233
- f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
234
  f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
235
- f"<p>Using original prompt. Error: {str(e)}</p>"
236
  f"</div>"
237
  )
238
  else:
@@ -247,9 +275,7 @@ def infer(
247
  unload_rewriter()
248
 
249
  # Set seed for reproducibility
250
- seed_val = seed
251
- if randomize_seed:
252
- seed_val = random.randint(0, 2**32 - 1)
253
  generator = torch.Generator(device=device).manual_seed(seed_val)
254
 
255
  try:
@@ -263,121 +289,140 @@ def infer(
263
  true_cfg_scale=true_guidance_scale,
264
  num_images_per_prompt=num_images_per_prompt
265
  ).images
 
 
266
  except Exception as e:
267
  gr.Error(f"Image generation failed: {str(e)}")
268
- prompt_info = (
269
- f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
270
- f"<h4 style='margin-top: 0;'><strong>⚠️ Error:</strong> {str(e)}</h4>"
 
271
  f"</div>"
272
  )
273
- return [], seed_val, prompt_info
274
-
275
- return edited_images, seed_val, prompt_info
276
 
277
  MAX_SEED = np.iinfo(np.int32).max
278
- examples = [
279
- "Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
280
- "Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
281
- "Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
282
- "Remove the blue sky and replace it with a dark night cityscape.",
283
- """Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used and position it at top left."""
284
- ]
285
 
286
- with gr.Blocks(title="Qwen Image Editor Fast") as demo:
287
  gr.Markdown("""
288
- <div style="text-align: center;">
289
- <h1>⚡️ Qwen-Image-Edit Lightning Fast 8-STEP</h1>
290
- <p>8-step image editing with lightx2v's LoRA and local prompt enhancement</p>
291
- <p>🚧 Work in progress, further improvements coming soon.</p>
292
  </div>
293
  """)
294
-
295
- with gr.Row():
296
  # Input Column
297
- with gr.Column():
298
- input_image = gr.Image(label="Input Image", type="pil")
299
- prompt = gr.Textbox(label="Edit Instruction", placeholder="e.g. Add a dog to the right side", lines=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- with gr.Accordion("Advanced Settings", open=False):
302
- gr.Markdown("### Generation Parameters")
303
  with gr.Row():
304
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
305
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
 
 
 
 
 
 
 
 
306
  with gr.Row():
307
  true_guidance_scale = gr.Slider(
308
- label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=4.0
 
 
 
 
309
  )
310
  num_inference_steps = gr.Slider(
311
- label="Inference Steps", minimum=4, maximum=16, step=1, value=8
312
- )
313
- num_images_per_prompt = gr.Slider(
314
- label="Output Images", minimum=1, maximum=4, step=1, value=2
 
315
  )
 
 
 
 
 
 
 
316
 
317
- rewrite_toggle = gr.Checkbox(
318
- label="Enable AI Prompt Enhancement",
319
- value=True
320
- )
321
-
322
- run_button = gr.Button("Generate Edits", variant="primary")
323
-
324
  # Output Column
325
- with gr.Column():
326
  result = gr.Gallery(
327
- label="Output Images",
328
- columns=lambda x: 2 if x > 1 else 1,
329
- object_fit="contain",
330
- height="auto"
 
331
  )
332
  prompt_info = gr.HTML(
333
- "<div style='margin-top:20px; padding:15px; border-radius:8px; background:#f8f9fa'>"
334
- "<p>Prompt details will appear here after generation</p></div>"
335
  )
336
 
337
- # gr.Examples(
338
- # examples=examples,
339
- # inputs=[prompt],
340
- # label="Try These Examples",
341
- # cache_examples=True
342
- # )
 
 
 
 
 
343
 
344
- # Main processing
345
- run_event = run_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  fn=infer,
347
- inputs=[
348
- input_image,
349
- prompt,
350
- seed,
351
- randomize_seed,
352
- true_guidance_scale,
353
- num_inference_steps,
354
- rewrite_toggle,
355
- num_images_per_prompt
356
- ],
357
- outputs=[result, seed, prompt_info]
358
  )
359
 
360
  prompt.submit(
361
  fn=infer,
362
- inputs=[
363
- input_image,
364
- prompt,
365
- seed,
366
- randomize_seed,
367
- true_guidance_scale,
368
- num_inference_steps,
369
- rewrite_toggle,
370
- num_images_per_prompt
371
- ],
372
- outputs=[result, seed, prompt_info]
373
- )
374
-
375
- # Vectorize prompt info visibility
376
- run_event.then(
377
- fn=lambda: gr.update(visible=True),
378
- inputs=None,
379
- outputs=[prompt_info],
380
- queue=False
381
  )
382
 
383
  if __name__ == "__main__":
 
9
  import os
10
  import re
11
  import gc
12
+ import json # Added json import
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
14
 
15
  #############################
 
79
 
80
  def extract_json_response(model_output: str) -> str:
81
  """Extract rewritten instruction from potentially messy JSON output"""
82
+ # New: Remove code block markers first
83
+ model_output = re.sub(r'```(?:json)?\s*', '', model_output)
84
+
85
  try:
86
  # Try to find the JSON portion in the output
87
  start_idx = model_output.find('{')
88
+ end_idx = model_output.rfind('}')
89
+ if start_idx == -1 or end_idx == -1:
90
  return None
91
 
92
+ # Expand to the full object including outer braces
93
+ end_idx += 1 # Include the closing brace
94
+
95
  json_str = model_output[start_idx:end_idx]
96
+
97
+ # Improved quote handling for values
98
+ json_str = re.sub(r'(\w+)\s*:', r'"\1":', json_str) # Quote keys
99
+ json_str = re.sub(r':\s*([^"\s{[]+)', r': "\1"', json_str) # Quote unquoted string values
100
 
101
  # Parse JSON
102
  data = json.loads(json_str)
103
 
104
  # Extract rewritten prompt from possible key variations
105
  possible_keys = [
106
+ "Rewritten", "rewritten", "Rewrited", "rewrited",
107
  "Output", "output", "Enhanced", "enhanced"
108
  ]
109
  for key in possible_keys:
 
113
  # Try nested path
114
  if "Response" in data and "Rewritten" in data["Response"]:
115
  return data["Response"]["Rewritten"].strip()
116
+
117
+ # Handle nested JSON objects (additional protection)
118
+ if isinstance(data, dict):
119
+ for value in data.values():
120
+ if isinstance(value, dict) and "Rewritten" in value:
121
+ return value["Rewritten"].strip()
122
 
123
+ # Try to find any string value that looks like an instruction
124
+ str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
125
+ if str_values:
126
+ return str_values[0].strip()
127
 
128
+ except Exception as e:
129
+ print(f"JSON parse error: {str(e)}")
 
130
 
131
+ return None
132
+
133
  def polish_prompt(original_prompt: str) -> str:
134
  """Enhanced prompt rewriting using original system prompt with JSON handling"""
135
  load_rewriter()
 
151
  with torch.no_grad():
152
  generated_ids = rewriter_model.generate(
153
  **model_inputs,
154
+ max_new_tokens=150, # Reduced for better quality
155
  do_sample=True,
156
+ temperature=0.4, # Less creative but more focused
157
  top_p=0.9,
158
+ no_repeat_ngram_size=3,
159
  pad_token_id=rewriter_tokenizer.eos_token_id
160
  )
161
 
 
165
  skip_special_tokens=True
166
  ).strip()
167
 
168
+ # New: Last-resort JSON content extraction
169
+ json_str = enhanced
170
+ if '```' in enhanced:
171
+ parts = enhanced.split('```')
172
+ if len(parts) >= 3:
173
+ json_str = parts[1] # Take content between first set of ```
174
+
175
  # Try to extract JSON content
176
+ rewritten_prompt = extract_json_response(json_str if '```' in enhanced else enhanced)
177
 
178
  if rewritten_prompt:
179
+ # Clean up remaining artifacts
180
+ rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
181
+ rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
182
  return rewritten_prompt
183
 
184
  # Fallback cleanup if JSON extraction fails
185
+ if '```' in enhanced:
186
+ # Extract content from code blocks
187
+ parts = enhanced.split('```')
188
+ if len(parts) >= 3:
189
+ rewritten_prompt = parts[1].strip()
190
+ else:
191
+ rewritten_prompt = enhanced
192
+ else:
193
+ rewritten_prompt = enhanced
194
 
195
+ # Improved cleaning of fallback output
196
+ rewritten_prompt = re.sub(r'.*{.*}.*', '', rewritten_prompt)
197
+ rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
198
+ if ': ' in rewritten_prompt:
199
+ rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
200
 
201
+ return rewritten_prompt[:200] # Ensure reasonable length
202
 
203
  # Load main image editing pipeline
204
  pipe = QwenImageEditPipeline.from_pretrained(
205
+ "Qwen/Qwen-Image-Edit",
206
  torch_dtype=dtype
207
  ).to(device)
208
 
209
  # Load LoRA weights for acceleration
210
  pipe.load_lora_weights(
211
+ "lightx2v/Qwen-Image-Lightning",
212
  weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
213
  )
214
  pipe.fuse_lora()
 
218
  else:
219
  print("xformers not available")
220
 
 
221
  def unload_rewriter():
222
  """Clear enhancement model from memory"""
223
  global rewriter_tokenizer, rewriter_model
 
234
  prompt,
235
  seed=42,
236
  randomize_seed=False,
237
+ true_guidance_scale=1.0,
238
  num_inference_steps=8,
239
  rewrite_prompt=False,
240
  num_images_per_prompt=1,
 
248
  try:
249
  enhanced_instruction = polish_prompt(original_prompt)
250
  prompt_info = (
251
+ f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
252
  f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
253
  f"<p><strong>Original:</strong> {original_prompt}</p>"
254
+ f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
255
  f"</div>"
256
  )
257
  prompt = enhanced_instruction
258
  except Exception as e:
259
  gr.Warning(f"Prompt enhancement failed: {str(e)}")
260
  prompt_info = (
261
+ f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
262
  f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
263
+ f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
264
  f"</div>"
265
  )
266
  else:
 
275
  unload_rewriter()
276
 
277
  # Set seed for reproducibility
278
+ seed_val = seed if not randomize_seed else random.randint(0, MAX_SEED)
 
 
279
  generator = torch.Generator(device=device).manual_seed(seed_val)
280
 
281
  try:
 
289
  true_cfg_scale=true_guidance_scale,
290
  num_images_per_prompt=num_images_per_prompt
291
  ).images
292
+ return edited_images, seed_val, prompt_info
293
+
294
  except Exception as e:
295
  gr.Error(f"Image generation failed: {str(e)}")
296
+ return [], seed_val, (
297
+ f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
298
+ f"<h4 style='margin-top: 0;'>⚠️ Processing Error</h4>"
299
+ f"<p>{str(e)[:200]}</p>"
300
  f"</div>"
301
  )
 
 
 
302
 
303
  MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
304
 
305
+ with gr.Blocks(title="Qwen Image Editor Fast", css=".gr-gallery {min-height: 300px}") as demo:
306
  gr.Markdown("""
307
+ <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
308
+ <h1 style="margin-bottom: 5px;">⚡️ Qwen-Image-Edit Lightning</h1>
309
+ <p>8-step inferencing Local Prompt Enhancement H200 Optimized</p>
 
310
  </div>
311
  """)
312
+
313
+ with gr.Row(equal_height=True):
314
  # Input Column
315
+ with gr.Column(scale=1):
316
+ input_image = gr.Image(
317
+ label="Source Image",
318
+ type="pil",
319
+ height=300
320
+ )
321
+ prompt = gr.Textbox(
322
+ label="Edit Instructions",
323
+ placeholder="e.g. Replace the background with a beach sunset...",
324
+ lines=2,
325
+ max_lines=4
326
+ )
327
+
328
+ with gr.Row():
329
+ rewrite_toggle = gr.Checkbox(
330
+ label="Enable Prompt Enhancement",
331
+ value=True,
332
+ interactive=True
333
+ )
334
+ run_button = gr.Button(
335
+ "Generate Edits",
336
+ variant="primary",
337
+ min_width=120
338
+ )
339
 
340
+ with gr.Accordion("Advanced Parameters", open=False):
 
341
  with gr.Row():
342
+ seed = gr.Slider(
343
+ label="Seed",
344
+ min=0,
345
+ max=MAX_SEED,
346
+ step=1,
347
+ value=42
348
+ )
349
+ randomize_seed = gr.Checkbox(
350
+ label="Random Seed",
351
+ value=True
352
+ )
353
  with gr.Row():
354
  true_guidance_scale = gr.Slider(
355
+ label="Guidance Scale",
356
+ min=1.0,
357
+ max=5.0,
358
+ step=0.1,
359
+ value=1.0
360
  )
361
  num_inference_steps = gr.Slider(
362
+ label="Inference Steps",
363
+ min=4,
364
+ max=16,
365
+ step=1,
366
+ value=8
367
  )
368
+ num_images_per_prompt = gr.Slider(
369
+ label="Output Count",
370
+ min=1,
371
+ max=4,
372
+ step=1,
373
+ value=1
374
+ )
375
 
 
 
 
 
 
 
 
376
  # Output Column
377
+ with gr.Column(scale=1):
378
  result = gr.Gallery(
379
+ label="Edited Images",
380
+ columns=lambda x: min(x, 2),
381
+ height=500,
382
+ object_fit="cover",
383
+ preview=True
384
  )
385
  prompt_info = gr.HTML(
386
+ value="<div style='padding:15px; background:#f8f9fa; border-radius:8px; margin-top:15px'>"
387
+ "Prompt details will appear after generation</div>"
388
  )
389
 
390
+ # Examples
391
+ gr.Examples(
392
+ examples=[
393
+ "Change the background scene to a rooftop bar at night",
394
+ "Transform to pixel art style with 8-bit graphics",
395
+ "Replace all text with 'Qwen AI' in futuristic font"
396
+ ],
397
+ inputs=[prompt],
398
+ label="Sample Instructions",
399
+ cache_examples=True
400
+ )
401
 
402
+ # Set up processing
403
+ inputs = [
404
+ input_image,
405
+ prompt,
406
+ seed,
407
+ randomize_seed,
408
+ true_guidance_scale,
409
+ num_inference_steps,
410
+ rewrite_toggle,
411
+ num_images_per_prompt
412
+ ]
413
+
414
+ outputs = [result, seed, prompt_info]
415
+
416
+ run_button.click(
417
  fn=infer,
418
+ inputs=inputs,
419
+ outputs=outputs
 
 
 
 
 
 
 
 
 
420
  )
421
 
422
  prompt.submit(
423
  fn=infer,
424
+ inputs=inputs,
425
+ outputs=outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  )
427
 
428
  if __name__ == "__main__":