FEREYDOONBER / app.py
FEREYDOONRAH's picture
Update app.py
9f58c97 verified
# blip2_generate_fix.py
import requests
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
MODEL = "Salesforce/blip2-opt-2.7b" # یا مدل موردنظر
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("loading processor and model (this may take a while)...")
processor = Blip2Processor.from_pretrained(MODEL, use_auth_token=None)
model = Blip2ForConditionalGeneration.from_pretrained(MODEL, use_auth_token=None)
model.to(DEVICE)
model.eval()
def generate_answer_from_image(pil_image: Image.Image, prompt_text: str = "",
max_new_tokens: int = 64, num_beams: int = 4, temperature: float = 1.0):
"""
Robust generation that avoids returning the prompt in the final output.
Returns: (generated_text, debug_info)
"""
# آماده‌سازی ورودی با پردازشگر
inputs = processor(pil_image, prompt_text, return_tensors="pt")
# تعیین طول پرامپت (تعداد توکن‌های ورودی)
input_len = 0
if "input_ids" in inputs and inputs["input_ids"] is not None:
input_len = inputs["input_ids"].shape[-1]
else:
# اگر input_ids وجود ندارد، متن پرامپت را با tokenizer رمزکن کن تا طولش را داشته باشیم
# (بعضی پردازشگرها pixel_values فقط بازمی‌گردانند)
tok = processor.tokenizer(prompt_text, return_tensors="pt")
if "input_ids" in tok and tok["input_ids"] is not None:
input_len = tok["input_ids"].shape[-1]
else:
input_len = 0
# انتقال به دستگاه
for k, v in inputs.items():
inputs[k] = v.to(DEVICE)
# تولید با جزئیات و بازگشت dict
gen_kwargs = dict(max_new_tokens=max_new_tokens, num_beams=num_beams)
# اگر temperature=0 => deterministic (do_sample=False)
if temperature is None or temperature == 0:
gen_kwargs["do_sample"] = False
else:
gen_kwargs["do_sample"] = True
gen_kwargs["temperature"] = float(temperature)
with torch.no_grad():
outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=False, **gen_kwargs)
# گرفتن توالی کامل (ممکن است tensor یا attribute باشد)
if hasattr(outputs, "sequences"):
sequences = outputs.sequences
else:
# fallback (older transformers)
sequences = outputs
# اطمینان از انتقال به CPU برای پردازش tokenizer
sequences_cpu = sequences.cpu()
# slice جدید: فقط توکن‌های جدید (اگر طول خروجی بیشتر از طول پرامپت بود)
try:
total_len = sequences_cpu.shape[-1]
except Exception:
total_len = None
generated_text = ""
debug = {"input_len": input_len, "total_len": total_len}
if total_len and input_len and total_len > input_len:
# فقط بخش تولیدشده را جدا کن
gen_tokens = sequences_cpu[:, input_len:]
# batch_decode منتظر لیست یا آرایه‌ی اعداد است
token_list = gen_tokens[0].tolist()
generated_text = processor.tokenizer.decode(token_list, skip_special_tokens=True).strip()
debug["method"] = "slice_tokens"
else:
# fallback: اگر نتوانستیم برش بزنیم، کل توالی را decode کن و سپس تلاش کن پرامپت متنی را از ابتدای
# خروجی حذف کنی (پاراگرافی). این روش آخرین امید است زیرا معیار دقیق‌تری نیست،
# ولی امن‌تر از بازگرداندن کل prompt است.
full = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
debug["method"] = "decode_full"
# تلاش برای حذف متن پرامپت (اولین وقوع) — فقط اگر prompt_text غیرخالی و در متن وجود داشته باشد
if prompt_text and prompt_text.strip():
# ممکن است تفاوت‌های tokenization منجر به عدم داشتن دقیق prompt_text شود؛
# ما اولین وقوع متن پرامپت را حذف می‌کنیم اگر دقیقاً در خروجی آمده باشد.
if prompt_text.strip() in full:
generated_text = full.replace(prompt_text.strip(), "", 1).strip()
debug["removed_prompt_by_string"] = True
else:
# به عنوان آخرین راه، اگر prompt کوتاه است، سعی می‌کنیم تا نزدیک‌ترین بخش را حذف کنیم
# (ایمن عمل کن: فقط اگر طول خروجی خیلی طولانی باشد)
generated_text = full
debug["removed_prompt_by_string"] = False
else:
generated_text = full
# اگر هنوز خالی بود، خروجی کامل را بازگردان کن (ولی این دیگر نباید پرامپت باشد)
if not generated_text:
# در صورت نیاز می‌توانیم full را بازگردانیم
generated_text = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
debug["final_fallback"] = True
return generated_text, debug
# -------------------- مثال اجرای تست --------------------
if __name__ == "__main__":
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
prompt = "Describe the image in detail in fluent English."
output, dbg = generate_answer_from_image(raw_image, prompt_text=prompt, max_new_tokens=64, num_beams=4, temperature=1.0)
print("=== GENERATED ===")
print(output)
print("=== DEBUG ===")
print(dbg)