Commit 6d75145 (parent: 3d1918c) by zhangjiewu

update prompt enhancer

Files changed (2):
  1. app.py +4 -24
  2. prompt_enhancer.py +14 -4
app.py CHANGED
@@ -82,11 +82,12 @@ if lora_path:
 )
 print(f"✓ Configured scheduler (flow_shift=2.0)")
 
+pipe.to(DEVICE)
 end = time.time()
 print(f"Model loaded in {end - start:.2f}s.")
 
 start = time.time()
-prompt_enhancer_model = "Qwen/Qwen3-VL-30B-A3B-Instruct"
+prompt_enhancer_model = "Qwen/Qwen3-VL-8B-Instruct"
 prompt_model, processor = load_model(prompt_enhancer_model)
 end = time.time()
 print(f"Prompt enhancer loaded in {end - start:.2f}s.")
@@ -124,12 +125,6 @@ def run_inference(
     shift: float = 2.0,
     num_temporal_reasoning_steps: int = 8,
 ):
-
-    start = time.time()
-    prompt_model.to(DEVICE)
-    end = time.time()
-    print(f"Pipeline moved to device in {end - start:.2f}s")
-
     # Rewriter
     final_prompt = prompt
 
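The deleted lines moved the prompt model onto the GPU at the top of every `run_inference` call. That round-trip is obsolete because the enhancer is now loaded with `device_map="auto"` (see the prompt_enhancer.py hunks below), which lets accelerate place the weights once at load time. Roughly:

```python
import torch
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

# device_map="auto" places the weights at load time, so no manual
# prompt_model.to(DEVICE) / prompt_model.to("cpu") shuttling is needed.
prompt_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-8B-Instruct",
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
```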
@@ -152,11 +147,6 @@ def run_inference(
     print("=" * 80 + "\n")
     final_prompt = cot_prompt
 
-    start = time.time()
-    prompt_model.to("cpu")
-    end = time.time()
-    print(f"Pipeline moved to cpu in {end - start:.2f}s")
-
     # Inference
     print(f"Loading input image: {image_path}")
     image = load_image(image_path)
@@ -169,11 +159,6 @@ def run_inference(
     image = image.resize((width, height))
     num_frames = 29 if enable_temporal_reasoning else 5
 
-    start = time.time()
-    pipe.to(DEVICE)
-    end = time.time()
-    print(f"Pipeline moved to device in {end - start:.2f}s")
-
     start = time.time()
     output = pipe(
         image=image,
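Same cleanup for the video pipeline: since `pipe` is moved to the device once at startup (first hunk), the move-before-generate is gone. If VRAM pressure ever forces offloading again, diffusers' built-in hooks would be the idiomatic replacement for hand-rolled `.to()` pairs:

```python
# Hedged alternative for tight-memory setups (assumes `pipe` is a standard
# diffusers pipeline): submodules visit the GPU only while they execute.
pipe.enable_model_cpu_offload()
```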
@@ -189,11 +174,6 @@ def run_inference(
     end = time.time()
     print(f"Generated video in {end - start:.2f}s")
 
-    start = time.time()
-    pipe.to("cpu")
-    end = time.time()
-    print(f"Pipeline moved to cpu in {end - start:.2f}s")
-
     video_tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
     output_path_video = video_tmp.name
     video_tmp.close()
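The surviving lines reserve a named temporary `.mp4` path and close the handle so a writer can reopen it. How the frames get written falls outside the diff context; a sketch assuming the usual diffusers conventions (`export_to_video`, frames at `output.frames[0]`, and the fps value are all assumptions here):

```python
import tempfile
from diffusers.utils import export_to_video

video_tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
output_path_video = video_tmp.name
video_tmp.close()  # keep the path; release the handle for the writer

# Assumed diffusers-style output layout and frame rate.
export_to_video(output.frames[0], output_path_video, fps=16)
```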
@@ -268,7 +248,7 @@ def build_ui() -> gr.Blocks:
         ],
         [
             "examples/2.png",
-            "The user wants to change the scene so that the girl in the traditional-style painting, wearing her ornate floral robe and headdress, is now playing a guitar.",
+            "The user wants to change the scene so that the girl in the traditional-style painting, wearing her ornate floral robe and headdress, is now playing a guitar. Her graceful appearance remains unchanged - smooth black hair tied neatly, soft facial features with a calm, focused expression - but her pose shifts: both hands are engaged with the guitar. One hand rests on the neck of the instrument, fingers pressing the strings with delicate precision, while the other hand strums near the sound hole. The guitar is positioned naturally across her lap, blending with the elegance of her posture. The traditional painting style is preserved, but the addition of the guitar introduces a modern contrast, giving the scene a harmonious fusion of classical refinement and contemporary music.",
             False,
         ],
         [
@@ -291,4 +271,4 @@ def build_ui() -> gr.Blocks:
 if __name__ == "__main__":
     demo = build_ui()
     # demo.launch(server_name="0.0.0.0", server_port=7869)
-    demo.queue().launch()
+    demo.queue().launch(share=True)
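`demo.queue()` serializes incoming requests so concurrent users don't contend for the GPU, and `share=True` additionally exposes the app through a temporary public `*.gradio.live` tunnel alongside the local server. For a single-GPU deployment, an explicit concurrency bound is a common companion setting (the value here is an assumption, not from this commit):

```python
# queue() with an explicit per-event concurrency cap (Gradio 4.x API);
# capping at 1 assumes one GPU-bound job at a time.
demo.queue(default_concurrency_limit=1).launch(share=True)
```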
prompt_enhancer.py CHANGED
@@ -17,8 +17,9 @@ import argparse
 import torch
 from PIL import Image
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
+    Qwen3VLForConditionalGeneration,
     Qwen3VLMoeForConditionalGeneration,
 )
 from qwen_vl_utils import process_vision_info
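The import block now carries three Qwen VL architectures, and each one needs a matching branch in `load_model` below. A hedged alternative that would stop both lists from growing, assuming every checkpoint involved is registered with transformers' image-text-to-text auto class (true for Qwen2.5-VL and Qwen3-VL in recent releases):

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Sketch only: the Auto* classes dispatch on each checkpoint's config,
# collapsing the per-architecture if/elif ladder into one code path.
def load_model(model_name):
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_name)
    return model, processor
```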
@@ -108,13 +109,22 @@ def load_model(model_name):
         )
         processor = AutoProcessor.from_pretrained(model_name)
 
+    elif model_name == "Qwen/Qwen3-VL-8B-Instruct":
+        model = Qwen3VLForConditionalGeneration.from_pretrained(
+            model_name,
+            dtype=torch.bfloat16,
+            attn_implementation=attn_impl,
+            device_map="auto"
+        )
+        processor = AutoProcessor.from_pretrained(model_name)
+
     else:
         raise ValueError(f"Unsupported model: {model_name}")
 
     return model, processor
 
 
-def resize_if_needed(image, max_resolution=1080):
+def resize_if_needed(image, max_resolution):
     """Resize image so that the shortest edge is at most max_resolution pixels."""
     width, height = image.size
     if min(width, height) > max_resolution:
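Two things happen in this hunk: the new `elif` wires up the dense 8B checkpoint with the same bfloat16 / `device_map="auto"` settings as its siblings, and `resize_if_needed` loses its `max_resolution=1080` default, so callers must now pass the limit explicitly. The function body is truncated by the diff context; it presumably continues along these lines (the scale arithmetic is an assumption consistent with the docstring):

```python
def resize_if_needed(image, max_resolution):
    """Resize image so that the shortest edge is at most max_resolution pixels."""
    width, height = image.size
    if min(width, height) > max_resolution:
        scale = max_resolution / min(width, height)
        image = image.resize((round(width * scale), round(height * scale)))
    return image
```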
@@ -152,7 +162,7 @@ def _run_model_inference(messages, model, processor):
         inputs = inputs.to(model.device).to(model.dtype)
         generated_ids = model.generate(**inputs, max_new_tokens=512)
 
-    elif isinstance(model, Qwen3VLMoeForConditionalGeneration):
+    elif isinstance(model, Qwen3VLMoeForConditionalGeneration) or isinstance(model, Qwen3VLForConditionalGeneration):
         inputs = processor.apply_chat_template(
             messages,
             tokenize=True,
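The widened condition routes the dense model through the same path as the MoE one (`isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLForConditionalGeneration))` would be the tighter spelling). The branch is cut off mid-call by the diff context; it presumably completes with the standard Qwen3-VL chat-template recipe, sketched here with the decode step assumed:

```python
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=512)

# Assumed decode: drop the prompt tokens, then detokenize the continuation.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
text = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
```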