LPX55 committed · Commit 3a64b85 · verified · 1 parent: 7585b3f

Update app_local.py

Files changed (1): app_local.py (+28 −32)
app_local.py CHANGED
@@ -18,11 +18,13 @@ os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
 
 # Model configuration
 REWRITER_MODEL = "Qwen/Qwen1.5-7B-Chat"  # Upgraded to 7B for better JSON handling
-rewriter_tokenizer = None
-rewriter_model = None
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Preload enhancement model at startup
+print("🔄 Loading prompt enhancement model...")
+rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
+
 # Quantization configuration
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
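
The hunk boundary cuts bnb_config in half, so its two middle lines never appear in this diff. For orientation, a complete 4-bit setup of this kind typically reads as below; the quant_type and compute_dtype values are assumptions (common defaults), not lines from this commit:

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",              # assumption: usual NF4 choice
        bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: matches module-level dtype
        bnb_4bit_use_double_quant=True,
    )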
@@ -31,19 +33,14 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_use_double_quant=True
 )
 
-def load_rewriter():
-    """Lazily load the prompt enhancement model"""
-    global rewriter_tokenizer, rewriter_model
-    if rewriter_tokenizer is None or rewriter_model is None:
-        print("🔄 Loading enhancement model...")
-        rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
-        rewriter_model = AutoModelForCausalLM.from_pretrained(
-            REWRITER_MODEL,
-            torch_dtype=dtype,
-            device_map="auto",
-            quantization_config=bnb_config
-        )
-        print("✅ Enhancement model loaded")
+rewriter_model = AutoModelForCausalLM.from_pretrained(
+    REWRITER_MODEL,
+    torch_dtype=dtype,
+    device_map="auto",
+    quantization_config=bnb_config,
+    max_memory={0: "48GiB"},  # Cap what device_map may place on GPU 0
+)
+print("✅ Enhancement model loaded and ready!")
 
 SYSTEM_PROMPT_EDIT = '''
 # Edit Instruction Rewriter
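
With the rewriter now loaded eagerly at import time rather than on first use, it is worth confirming where accelerate placed it and what the preload actually costs. A minimal check using the script's own names (the check itself is not part of the commit):

    print(rewriter_model.hf_device_map)  # populated by from_pretrained when device_map="auto"
    print(f"{torch.cuda.memory_allocated(0) / 2**30:.1f} GiB held after preload")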
@@ -103,7 +100,7 @@ def extract_json_response(model_output: str) -> str:
 
     # Extract rewritten prompt from possible key variations
     possible_keys = [
-        "Rewritten", "rewritten", "Rewrited", "rewrited",
+        "Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
         "Output", "output", "Enhanced", "enhanced"
     ]
     for key in possible_keys:
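
Adding "Rewrittent" grows the list one misspelling at a time; a prefix match over the parsed object would absorb the whole family. A sketch, assuming the enclosing function holds the decoded JSON in a dict (the name `data` is illustrative):

    def find_rewritten(data: dict) -> str | None:
        """Return the first string value whose key looks like a rewrite field."""
        for key, value in data.items():
            # covers Rewritten/Rewrited/Rewrittent/Output/Enhanced in any casing
            if key.lower().startswith(("rewrit", "output", "enhanc")) and isinstance(value, str):
                return value
        return None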
@@ -151,9 +148,9 @@ def polish_prompt(original_prompt: str) -> str:
     with torch.no_grad():
         generated_ids = rewriter_model.generate(
             **model_inputs,
-            max_new_tokens=150,  # Reduced for better quality
+            max_new_tokens=256,  # Raised so longer rewrites aren't truncated
             do_sample=True,
-            temperature=0.4,  # Less creative but more focused
+            temperature=0.5,  # Less creative but more focused
             top_p=0.9,
             no_repeat_ngram_size=3,
             pad_token_id=rewriter_tokenizer.eos_token_id
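
For context, generate() returns the prompt and completion ids together, so the decode step that follows this hunk (outside the diff) would normally strip the prompt tokens first; the usual pattern looks like:

    # slice off the prompt portion before decoding the rewrite
    new_tokens = generated_ids[0][model_inputs["input_ids"].shape[1]:]
    rewritten = rewriter_tokenizer.decode(new_tokens, skip_special_tokens=True)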
@@ -218,15 +215,15 @@ if is_xformers_available():
 else:
     print("xformers not available")
 
-def unload_rewriter():
-    """Clear enhancement model from memory"""
-    global rewriter_tokenizer, rewriter_model
-    if rewriter_model:
-        del rewriter_tokenizer, rewriter_model
-        rewriter_tokenizer = None
-        rewriter_model = None
-        torch.cuda.empty_cache()
-        gc.collect()
+# def unload_rewriter():
+#     """Clear enhancement model from memory"""
+#     global rewriter_tokenizer, rewriter_model
+#     if rewriter_model:
+#         del rewriter_tokenizer, rewriter_model
+#         rewriter_tokenizer = None
+#         rewriter_model = None
+#         torch.cuda.empty_cache()
+#         gc.collect()
 
 @spaces.GPU(duration=60)
 def infer(
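
With the unload path commented out, the rewriter stays resident for the process lifetime. Should VRAM pressure return, a lighter middle ground than a full unload is to drop only the allocator cache between requests; a sketch, not part of the commit:

    def release_cache() -> None:
        """Free cached CUDA blocks without unloading the rewriter."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the driver
            torch.cuda.ipc_collect()  # reclaim memory held for IPC sharing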
@@ -271,9 +268,7 @@ def infer(
         f"</div>"
     )
 
-    # Free VRAM after enhancement
-    unload_rewriter()
-
+
     # Set seed for reproducibility
     seed_val = seed if not randomize_seed else random.randint(0, MAX_SEED)
     generator = torch.Generator(device=device).manual_seed(seed_val)
@@ -302,11 +297,12 @@ def infer(
 
 MAX_SEED = np.iinfo(np.int32).max
 
-with gr.Blocks(title="Qwen Image Editor Fast", css=".gr-gallery {min-height: 300px}") as demo:
+with gr.Blocks(title="Qwen Image Editor Fast") as demo:
     gr.Markdown("""
     <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
         <h1 style="margin-bottom: 5px;">⚡️ Qwen-Image-Edit Lightning</h1>
-        <p>8-step inferencing Local Prompt Enhancement • H200 Optimized</p>
+        <p>✨ 8-step inference with lightx2v's LoRA</p>
+        <p>📝 Local Prompt Enhancement</p>
     </div>
     """)
 
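
A side effect of this last hunk: dropping the css kwarg also drops the gallery's 300px minimum height. If the empty gallery collapses, the old rule can be reattached the same way it was removed:

    with gr.Blocks(title="Qwen Image Editor Fast", css=".gr-gallery {min-height: 300px}") as demo:
        ...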
 
 
308