linoyts HF Staff committed on
Commit
5c5af2f
·
verified ·
1 Parent(s): f59d578

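Update app.py: next_scene_prompt now takes original_prompt (returned as the fallback), sends multiple images as base64 image_url entries, and switches to provider "nebius" with model Qwen/Qwen2.5-VL-72B-Instruct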

Browse files
Files changed (1) hide show
  1. app.py +61 -15
app.py CHANGED
@@ -205,37 +205,81 @@ def polish_prompt_hf(prompt, img_list):
205
  # Fallback to original prompt if enhancement fails
206
  return prompt
207
 
208
- def next_scene_prompt(img_list):
209
  """
210
  Rewrites the prompt using a Hugging Face InferenceClient.
 
211
  """
212
  # Ensure HF_TOKEN is set
213
  api_key = os.environ.get("HF_TOKEN")
214
  if not api_key:
215
  print("Warning: HF_TOKEN not set. Falling back to original prompt.")
216
- return prompt
217
-
 
218
  try:
219
  # Initialize the client
220
- prompt = f"{NEXT_SCENE_SYSTEM_PROMPT}"
221
  client = InferenceClient(
222
- provider="cerebras",
223
  api_key=api_key,
224
  )
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  # Format the messages for the chat completions API
227
- sys_promot = "you are a helpful assistant, you should provide useful answers to users."
228
  messages = [
229
- {"role": "system", "content": sys_promot},
230
- {"role": "user", "content": []}]
231
- for img in img_list:
232
- messages[1]["content"].append(
233
- {"image": f"data:image/png;base64,{encode_image(img)}"})
234
- messages[1]["content"].append({"text": f"{prompt}"})
235
 
236
  # Call the API
237
  completion = client.chat.completions.create(
238
- model="Qwen/Qwen3-235B-A22B-Instruct-2507",
239
  messages=messages,
240
  )
241
 
@@ -243,7 +287,7 @@ def next_scene_prompt(img_list):
243
  result = completion.choices[0].message.content
244
 
245
  # Try to extract JSON if present
246
- if '{"Rewritten"' in result:
247
  try:
248
  # Clean up the response
249
  result = result.replace('```json', '').replace('```', '')
@@ -260,7 +304,9 @@ def next_scene_prompt(img_list):
260
  except Exception as e:
261
  print(f"Error during API call to Hugging Face: {e}")
262
  # Fallback to original prompt if enhancement fails
263
- return prompt
 
 
264
 
265
  def encode_image(pil_image):
266
  import io
 
205
  # Fallback to original prompt if enhancement fails
206
  return prompt
207
 
208
+ def next_scene_prompt(original_prompt, img_list):
209
  """
210
  Rewrites the prompt using a Hugging Face InferenceClient.
211
+ Supports multiple images via img_list.
212
  """
213
  # Ensure HF_TOKEN is set
214
  api_key = os.environ.get("HF_TOKEN")
215
  if not api_key:
216
  print("Warning: HF_TOKEN not set. Falling back to original prompt.")
217
+ return original_prompt
218
+ prompt = f"{NEXT_SCENE_SYSTEM_PROMPT}"
219
+ system_prompt = "you are a helpful assistant, you should provide useful answers to users."
220
  try:
221
  # Initialize the client
 
222
  client = InferenceClient(
223
+ provider="nebius",
224
  api_key=api_key,
225
  )
226
 
227
+ # Convert list of images to base64 data URLs
228
+ image_urls = []
229
+ if img_list is not None:
230
+ # Ensure img_list is actually a list
231
+ if not isinstance(img_list, list):
232
+ img_list = [img_list]
233
+
234
+ for img in img_list:
235
+ image_url = None
236
+ # If img is a PIL Image
237
+ if hasattr(img, 'save'): # Check if it's a PIL Image
238
+ buffered = BytesIO()
239
+ img.save(buffered, format="PNG")
240
+ img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
241
+ image_url = f"data:image/png;base64,{img_base64}"
242
+ # If img is already a file path (string)
243
+ elif isinstance(img, str):
244
+ with open(img, "rb") as image_file:
245
+ img_base64 = base64.b64encode(image_file.read()).decode('utf-8')
246
+ image_url = f"data:image/png;base64,{img_base64}"
247
+ else:
248
+ print(f"Warning: Unexpected image type: {type(img)}, skipping...")
249
+ continue
250
+
251
+ if image_url:
252
+ image_urls.append(image_url)
253
+
254
+ # Build the content array with text first, then all images
255
+ content = [
256
+ {
257
+ "type": "text",
258
+ "text": prompt
259
+ }
260
+ ]
261
+
262
+ # Add all images to the content
263
+ for image_url in image_urls:
264
+ content.append({
265
+ "type": "image_url",
266
+ "image_url": {
267
+ "url": image_url
268
+ }
269
+ })
270
+
271
  # Format the messages for the chat completions API
 
272
  messages = [
273
+ {"role": "system", "content": system_prompt},
274
+ {
275
+ "role": "user",
276
+ "content": content
277
+ }
278
+ ]
279
 
280
  # Call the API
281
  completion = client.chat.completions.create(
282
+ model="Qwen/Qwen2.5-VL-72B-Instruct",
283
  messages=messages,
284
  )
285
 
 
287
  result = completion.choices[0].message.content
288
 
289
  # Try to extract JSON if present
290
+ if '"Rewritten"' in result:
291
  try:
292
  # Clean up the response
293
  result = result.replace('```json', '').replace('```', '')
 
304
  except Exception as e:
305
  print(f"Error during API call to Hugging Face: {e}")
306
  # Fallback to original prompt if enhancement fails
307
+ return original_prompt
308
+
309
+
310
 
311
  def encode_image(pil_image):
312
  import io