# blip2_generate_fix.py import requests from PIL import Image import torch from transformers import Blip2Processor, Blip2ForConditionalGeneration MODEL = "Salesforce/blip2-opt-2.7b" # یا مدل موردنظر DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print("loading processor and model (this may take a while)...") processor = Blip2Processor.from_pretrained(MODEL, use_auth_token=None) model = Blip2ForConditionalGeneration.from_pretrained(MODEL, use_auth_token=None) model.to(DEVICE) model.eval() def generate_answer_from_image(pil_image: Image.Image, prompt_text: str = "", max_new_tokens: int = 64, num_beams: int = 4, temperature: float = 1.0): """ Robust generation that avoids returning the prompt in the final output. Returns: (generated_text, debug_info) """ # آماده‌سازی ورودی با پردازشگر inputs = processor(pil_image, prompt_text, return_tensors="pt") # تعیین طول پرامپت (تعداد توکن‌های ورودی) input_len = 0 if "input_ids" in inputs and inputs["input_ids"] is not None: input_len = inputs["input_ids"].shape[-1] else: # اگر input_ids وجود ندارد، متن پرامپت را با tokenizer رمزکن کن تا طولش را داشته باشیم # (بعضی پردازشگرها pixel_values فقط بازمی‌گردانند) tok = processor.tokenizer(prompt_text, return_tensors="pt") if "input_ids" in tok and tok["input_ids"] is not None: input_len = tok["input_ids"].shape[-1] else: input_len = 0 # انتقال به دستگاه for k, v in inputs.items(): inputs[k] = v.to(DEVICE) # تولید با جزئیات و بازگشت dict gen_kwargs = dict(max_new_tokens=max_new_tokens, num_beams=num_beams) # اگر temperature=0 => deterministic (do_sample=False) if temperature is None or temperature == 0: gen_kwargs["do_sample"] = False else: gen_kwargs["do_sample"] = True gen_kwargs["temperature"] = float(temperature) with torch.no_grad(): outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=False, **gen_kwargs) # گرفتن توالی کامل (ممکن است tensor یا attribute باشد) if hasattr(outputs, "sequences"): sequences = outputs.sequences else: # fallback (older transformers) sequences = outputs # اطمینان از انتقال به CPU برای پردازش tokenizer sequences_cpu = sequences.cpu() # slice جدید: فقط توکن‌های جدید (اگر طول خروجی بیشتر از طول پرامپت بود) try: total_len = sequences_cpu.shape[-1] except Exception: total_len = None generated_text = "" debug = {"input_len": input_len, "total_len": total_len} if total_len and input_len and total_len > input_len: # فقط بخش تولیدشده را جدا کن gen_tokens = sequences_cpu[:, input_len:] # batch_decode منتظر لیست یا آرایه‌ی اعداد است token_list = gen_tokens[0].tolist() generated_text = processor.tokenizer.decode(token_list, skip_special_tokens=True).strip() debug["method"] = "slice_tokens" else: # fallback: اگر نتوانستیم برش بزنیم، کل توالی را decode کن و سپس تلاش کن پرامپت متنی را از ابتدای # خروجی حذف کنی (پاراگرافی). این روش آخرین امید است زیرا معیار دقیق‌تری نیست، # ولی امن‌تر از بازگرداندن کل prompt است. full = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip() debug["method"] = "decode_full" # تلاش برای حذف متن پرامپت (اولین وقوع) — فقط اگر prompt_text غیرخالی و در متن وجود داشته باشد if prompt_text and prompt_text.strip(): # ممکن است تفاوت‌های tokenization منجر به عدم داشتن دقیق prompt_text شود؛ # ما اولین وقوع متن پرامپت را حذف می‌کنیم اگر دقیقاً در خروجی آمده باشد. if prompt_text.strip() in full: generated_text = full.replace(prompt_text.strip(), "", 1).strip() debug["removed_prompt_by_string"] = True else: # به عنوان آخرین راه، اگر prompt کوتاه است، سعی می‌کنیم تا نزدیک‌ترین بخش را حذف کنیم # (ایمن عمل کن: فقط اگر طول خروجی خیلی طولانی باشد) generated_text = full debug["removed_prompt_by_string"] = False else: generated_text = full # اگر هنوز خالی بود، خروجی کامل را بازگردان کن (ولی این دیگر نباید پرامپت باشد) if not generated_text: # در صورت نیاز می‌توانیم full را بازگردانیم generated_text = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip() debug["final_fallback"] = True return generated_text, debug # -------------------- مثال اجرای تست -------------------- if __name__ == "__main__": img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') prompt = "Describe the image in detail in fluent English." output, dbg = generate_answer_from_image(raw_image, prompt_text=prompt, max_new_tokens=64, num_beams=4, temperature=1.0) print("=== GENERATED ===") print(output) print("=== DEBUG ===") print(dbg)