# Pixcribe / caption_generation_manager.py
# Uploaded by DawnC in commit 6a3bd1f ("Upload 22 files", verified).
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from typing import List, Dict
import json
from opencc import OpenCC
import warnings
class CaptionGenerationManager:
    """Caption generation using Vision-Language Models (supports Qwen2.5-VL, Qwen3-VL, etc.)

    One instance loads a processor/model pair through the transformers
    ``Auto*`` classes. ``generate_captions`` prompts the model once per style
    (professional / creative / authentic) and returns the parsed, cleaned
    caption dicts, or a canned fallback when nothing parses.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
        """Load the processor/model and set generation and platform defaults.

        Args:
            model_name: Vision-Language model name, e.g.:
                - "Qwen/Qwen2.5-VL-7B-Instruct" (default)
                - "Qwen/Qwen3-VL-8B-Instruct"
        """
        print(f"Loading Vision-Language Model: {model_name}...")
        # Suppress FutureWarnings raised from inside transformers (processor warnings).
        warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
        # Auto* classes keep the loader model-agnostic (Qwen2.5-VL, Qwen3-VL, ...).
        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            dtype=torch.bfloat16,  # `dtype` is used instead of the older `torch_dtype` keyword
            device_map="auto",
        )
        # Simplified -> Traditional Chinese converter for Chinese output.
        self.cc = OpenCC('s2t')
        # Defaults for model.generate(); generate_captions overrides 'temperature' per style.
        # do_sample is explicit: without sampling, transformers ignores temperature/top_p,
        # so the per-style temperatures would have no effect.
        # NOTE(review): the checkpoint's own generation_config (e.g. a small top_k) still
        # applies to keys not set here — confirm it does not make sampling near-greedy.
        self.generation_config = {
            'do_sample': True,
            'temperature': 0.7,
            'top_p': 0.9,
            'max_new_tokens': 300,  # raised from 200 to prevent truncated output
            'repetition_penalty': 1.1,
        }
        # Per-platform prompt parameters; unknown platforms fall back to 'instagram'.
        self.platform_templates = {
            'instagram': {
                'style': 'storytelling, aesthetic',
                'emoji_count': '2-3',
                'hashtag_count': '8-10',
                'min_length': 120,
                'max_length': 220,
                'features': ['call-to-action', 'question', 'relatable'],
            },
            'tiktok': {
                'style': 'brief, punchy',
                'emoji_count': '1-2',
                'hashtag_count': '5-8',
                'min_length': 60,
                'max_length': 120,
                'features': ['trending', 'POV', 'relatable'],
            },
            'xiaohongshu': {
                'style': 'structured, informative, detailed',
                'emoji_count': '5-8',
                'hashtag_count': '8-12',
                'min_length': 180,
                'max_length': 500,
                'features': ['tips', 'bullets', 'sharing-tone'],
            },
        }
        print(f"✓ {model_name.split('/')[-1]} loaded successfully (using Auto* classes for flexibility)")

    def construct_prompt(self, analysis_results: Dict, platform: str = 'instagram', language: str = 'zh') -> str:
        """Build the full text prompt for one caption request.

        Args:
            analysis_results: Upstream analysis. Keys read (all optional, None tolerated):
                'detections' (list of dicts with 'class_name'; first 10 used),
                'brands' (sequence whose items' [0] is the brand name; first 5 used),
                'scene_analysis' (dict with 'lighting' {'top', 'confidence'},
                'urban' {'top'}, 'mood' {'top'}), 'composition' ({'composition_type'}).
            platform: Key into ``self.platform_templates``; unknown -> 'instagram'.
            language: 'zh' (Traditional Chinese), 'en' (English), 'zh-en' (Bilingual).
                Any other value is treated as 'zh' throughout the prompt.

        Returns:
            The prompt string (system instruction, context, and output format).
        """
        platform_config = self.platform_templates.get(platform, self.platform_templates['instagram'])

        # Language-specific instructions
        language_instructions = {
            'zh': '請使用繁體中文生成標題和標籤。語言要自然流暢,符合華語社群媒體的表達習慣。避免使用簡體字。當偵測到品牌時,必須在標題中提及品牌名稱。',
            'en': '''🚨 CRITICAL LANGUAGE REQUIREMENT 🚨
Generate captions and hashtags EXCLUSIVELY in English.
- NEVER use Chinese characters (Traditional or Simplified)
- NEVER mix languages
- Use natural, engaging language suitable for international social media
- When brands are detected, mention them naturally in English
- All text output must be 100% English only
This is MANDATORY and NON-NEGOTIABLE.''',
            'zh-en': '''生成雙語內容:標題使用繁體中文,同時提供英文翻譯。標籤混合使用中英文以擴大觸及範圍。當偵測到品牌時,必須在標題中提及品牌名稱。
🚨 重要:雙語一致性要求 🚨
- 中文和英文必須表達相同的核心意義
- 允許表達方式的差異(形容詞、語法不同)
- 但整體訊息、語氣、品牌提及必須一致
- 兩種語言都要朝同一方向詮釋內容''',
        }
        # Unknown codes fall back to 'zh' for every prompt section (not just the
        # system instruction), matching _get_fallback_caption's non-'en' branch.
        if language not in language_instructions:
            language = 'zh'

        system_instruction = f"""You are a professional social media content strategist.
{language_instructions[language]}
Target platform: {platform}
Content style: Authentic, creative, and optimized for engagement.
CRITICAL RULE: Never include hashtags (symbols starting with #) in the caption text. Hashtags must only appear in the separate 'hashtags' array."""

        # Extract analysis context; sections that are present but None count as empty.
        objects = analysis_results.get('detections') or []
        brands = analysis_results.get('brands') or []
        scene_info = analysis_results.get('scene_analysis') or {}
        composition = analysis_results.get('composition') or {}

        # Lighting is read from scene_info because DetectionFusionManager writes the fused result there.
        lighting_info = scene_info.get('lighting') or {}
        lighting = lighting_info.get('top') or 'natural light'
        try:
            lighting_confidence = float(lighting_info.get('confidence', 0.7))
        except (TypeError, ValueError):
            # None / non-numeric confidence would break the ':.2f' format below.
            lighting_confidence = 0.7

        # Explicit Chinese translations keep the lighting wording consistent
        lighting_translations_zh = {
            'soft diffused light': '柔和漫射光',
            'overcast atmosphere': '陰天氛圍',
            'natural daylight': '自然日光',
            'warm ambient light': '溫暖環境光',
            'evening light': '傍晚光線',
            'bright sunlight': '明亮陽光',
            'golden hour': '金黃時刻',
            'blue hour': '藍調時刻',
        }
        # Chinese prompts use the translated phrase; English and bilingual keep the English term.
        if language == 'zh':
            lighting_display = lighting_translations_zh.get(lighting, lighting)
        else:
            lighting_display = lighting

        objects_str = ', '.join(obj['class_name'] for obj in objects[:10])

        # Brand-dependent prompt fragments; all are empty/neutral when no brand was detected.
        brands_list = [b[0] for b in brands[:5]]
        if brands_list:
            primary_brand = brands_list[0]
            brands_str = ', '.join(brands_list)
            # Brands are repeated prominently because the model tends to drop them.
            brand_emphasis = f"""
🚨 CRITICAL BRAND REQUIREMENT 🚨
The following brands were POSITIVELY IDENTIFIED in this image: {brands_str}
YOU ABSOLUTELY MUST:
1. Mention the brand name "{primary_brand}" explicitly in the FIRST sentence
2. Use the exact brand name - do not use generic terms like "bag" or "accessory" without the brand
3. Write naturally as if you're excited to share this {primary_brand} item
4. Example: "在傍晚光線下,這款{primary_brand}經典黑色菱格紋皮革包..." (CORRECT)
5. NOT acceptable: "在傍晚光線下,這款經典黑色菱格紋皮革包..." (WRONG - missing brand name!)
THIS IS MANDATORY - The caption will be rejected if it doesn't mention {brands_str}.
"""
            hook_mandatory = f"- 🚨 MANDATORY: Start with the BRAND NAME '{primary_brand}' in the FIRST sentence!"
            hook_correct = f"- Example (CORRECT): '這款{primary_brand}經典黑色菱格紋皮革包...'"
            hook_wrong = f"- Example (WRONG): '這款經典黑色菱格紋皮革包...' (missing {primary_brand}!)"
            if language == 'zh':
                brand_rule_example = f"'這款{primary_brand}經典黑色...'"
            else:
                brand_rule_example = f"'This {primary_brand} classic black...'"
            brand_rule = (
                f"🚨 CRITICAL BRAND REQUIREMENT 🚨 - The brand name '{primary_brand}' MUST appear in the "
                f"FIRST sentence of your caption. This is MANDATORY and NON-NEGOTIABLE. Example: {brand_rule_example}"
            )
            example_brand = primary_brand
        else:
            brands_str = 'None detected'
            brand_emphasis = ""
            hook_mandatory = hook_correct = hook_wrong = ""
            brand_rule = "No brands detected to mention"
            example_brand = "Gucci"  # placeholder brand for the worked example

        urban_scene = (scene_info.get('urban') or {}).get('top', 'unknown')
        mood = (scene_info.get('mood') or {}).get('top', 'neutral')
        comp_type = composition.get('composition_type', 'standard')

        context = f"""
Analyze this image and generate an engaging, DETAILED social media caption with rich visual descriptions.
**Visual Elements (Describe in Detail):**
- Detected objects: {objects_str}
- Scene composition: {comp_type}
- Urban environment: {urban_scene}
- **IMPORTANT**: Include specific details about:
* Materials (leather, metal, fabric, canvas, etc.)
* Colors (use descriptive terms: jet black, antique gold, midnight blue, etc.)
* Textures (quilted, smooth, matte, glossy, metallic, etc.)
* Design features (stitching patterns, hardware, logos, emblems, etc.)
* Reflections and lighting effects on surfaces
**Atmosphere:**
- Lighting (analyzed with Places365 + CV): {lighting_display} (confidence: {lighting_confidence:.2f})
- Mood: {mood}
**Brand Detection:**
- Identified brands: {brands_str}{brand_emphasis}
**Caption Structure (Required - BE SPECIFIC AND DETAILED):**
1. Opening hook - Most striking visual element with SPECIFIC details (1-2 sentences)
{hook_mandatory}
{hook_correct}
{hook_wrong}
- Be SPECIFIC: Include material, color, design features WITH the brand name
2. Visual details - Describe materials, textures, colors, and design elements (2-3 sentences)
- Be SPECIFIC: mention quilting patterns, metal finishes, chain details, logo placements
- Describe how light interacts with materials (reflections on leather, gleam of metal)
- MUST use the EXACT lighting description: "{lighting_display}"
3. Atmospheric context - How lighting and mood create the scene's character (1-2 sentences)
- Connect lighting to the overall visual impact
- Describe depth, shadows, contrasts
4. Emotional connection & Engagement - How this resonates with viewers + call-to-action (1 sentence)
**Content Requirements:**
- Minimum information: 3-4 specific visual details per caption
- Include material types, color descriptions, design features
- Describe how lighting affects the appearance
- Make it vivid and immersive
Platform style: {platform_config['style']}
"""

        # Language-specific worked examples (detailed visual description + brand name)
        if language == 'zh':
            example_correct = f"""正確範例 - 詳細描述 + 品牌提及 (繁體中文):
"在{lighting_display}的映襯下,這款{example_brand}經典黑色菱格紋皮革包展現奢華質感,V字形縫線在柔軟小牛皮上勾勒出精緻的幾何圖案,復古金色雙G標誌在深色背景中熠熠生輝。金屬鏈條肩帶反射著{lighting_display},增添層次感與立體效果。皮革表面細膩的光澤與霧面質地形成迷人對比,每個細節都彰顯義大利工藝的極致追求。這樣的{example_brand}單品不只是配件,更是品味與格調的完美詮釋。你的衣櫃裡有哪件經典單品?✨🖤"
注意:品牌名稱 "{example_brand}" 出現在第一句!這是正確的做法。
CRITICAL:
- 必須包含材質描述(皮革、金屬等)
- 必須包含顏色細節(黑色、復古金色等)
- 必須包含設計特點(縫線、標誌、鏈條等)
- 必須使用"{lighting_display}"來描述光線
"""
        elif language == 'en':
            example_correct = f"""CORRECT EXAMPLE - Detailed Description + Brand Mention (ENGLISH ONLY - NO CHINESE):
"Under the {lighting}, this {example_brand} classic black quilted leather bag showcases luxurious craftsmanship. V-shaped stitching traces intricate geometric patterns across supple calfskin, while the antique gold double-G logo gleams against the dark backdrop. The metal chain strap catches and reflects the {lighting}, adding dimension and depth to the piece. The leather surface presents a captivating contrast between fine sheen and matte texture, with every detail exemplifying Italian artisanship at its finest. This {example_brand} piece isn't just an accessory – it's a perfect expression of taste and sophistication. What's your timeless wardrobe essential? ✨🖤"
NOTE: Brand name "{example_brand}" appears in the FIRST sentence! This is the correct approach.
🚨 ABSOLUTE REQUIREMENT FOR ENGLISH MODE 🚨
- Output must be 100% ENGLISH - zero Chinese characters allowed
- MUST include material descriptions (leather, metal, etc.)
- MUST include color details (black, antique gold, etc.)
- MUST include design features (stitching, logo, chain, etc.)
- MUST use "{lighting}" to describe the lighting
- NO Chinese characters anywhere in the output
"""
        else:  # zh-en bilingual
            example_correct = """BILINGUAL EXAMPLE - 雙語範例:
Caption in Traditional Chinese, with English hashtags support.
(Details omitted for brevity)
"""

        # Language-specific hashtag instructions
        if language == 'zh':
            hashtag_instruction = """
【CRITICAL HASHTAG REQUIREMENT - 繁體中文】:
- ALL hashtags MUST be in Traditional Chinese (繁體中文)
- NEVER use English hashtags when language is 繁體中文
- Examples of CORRECT hashtags: ["時尚包包", "奢華風格", "皮革工藝", "精品配件"]
- Examples of WRONG hashtags: ["FashionBlogger", "LuxuryLifestyle"] - DO NOT USE THESE
"""
        elif language == 'en':
            hashtag_instruction = """
【CRITICAL HASHTAG REQUIREMENT - English】:
- ALL hashtags MUST be in English
- NEVER use Chinese characters in hashtags
- Examples of CORRECT hashtags: ["FashionBlogger", "LuxuryLifestyle", "LeatherCraft"]
"""
        else:  # zh-en
            hashtag_instruction = """
【CRITICAL HASHTAG REQUIREMENT - Bilingual】:
- Hashtags should MIX Traditional Chinese and English
- First half in Chinese, second half in English
- Example: ["時尚包包", "奢華風格", "FashionBlogger", "LuxuryLifestyle"]
"""

        english_rule = (
            "🚨 LANGUAGE REQUIREMENT 🚨 - Output must be 100% ENGLISH ONLY. NO Chinese characters allowed anywhere."
            if language == 'en' else ""
        )

        output_format = f"""
Generate output in JSON format:
{{
"caption": "string (minimum {platform_config['min_length']} chars, maximum {platform_config['max_length']} chars, engaging and descriptive)",
"hashtags": ["tag1", "tag2", ...] ({platform_config['hashtag_count']} relevant hashtags),
"tone": "casual|professional|playful",
"platform": "{platform}"
}}
{hashtag_instruction}
STRICT REQUIREMENTS:
1. Caption length: {platform_config['min_length']}-{platform_config['max_length']} characters
2. 🚨 EMOJI REQUIREMENT 🚨 - MUST use EXACTLY {platform_config['emoji_count']} emojis naturally integrated into caption text
- Professional style: 1-2 emojis (e.g., ✨💼🌟)
- Creative style: 2-3 emojis (e.g., 🎨✨💫🌙)
- Authentic style: 2-3 emojis (e.g., 💖👜✨🖤)
- Place emojis naturally within or at end of sentences
3. Caption must be pure descriptive text only - absolutely NO hashtags allowed
4. 🚨 CALL-TO-ACTION REQUIREMENT 🚨 - MUST include an engaging question or CTA at the end
- Professional: Brief professional question (e.g., "What's your go-to piece?")
- Creative: Thought-provoking question (e.g., "How does this speak to you?")
- Authentic: Personal question (e.g., "What's your favorite timeless accessory?")
5. Write 3-4 complete sentences following the structure above
6. Be specific and vivid - describe what you see in detail
7. 【CRITICAL】 MUST use the EXACT lighting description: "{lighting_display}"
- DO NOT substitute with similar terms
- DO NOT use "金黃時刻" if the lighting is "{lighting_display}"
- DO NOT invent your own lighting description
8. 🚨 HASHTAG REQUIREMENT 🚨 - Generate {platform_config['hashtag_count']} relevant hashtags
- Hashtags go ONLY in the 'hashtags' array, NEVER in the caption text
- Mix of broad and specific tags
- Include brand name as hashtag if detected
9. {brand_rule}
10. {english_rule}
WRONG EXAMPLE (DO NOT DO THIS):
"Lost in the city's towering skyscrapers 🏙️✨ | #UrbanVibes #CityLife"
{example_correct}
"""
        return f"{system_instruction}\n\n{context}\n\n{output_format}"

    def generate_captions(self, analysis_results: Dict, image: "Image.Image",
                          platform: str = 'instagram', language: str = 'zh') -> List[Dict]:
        """Generate 3 captions with distinct styles: Professional, Creative, Authentic.

        Args:
            analysis_results: Same structure as accepted by ``construct_prompt``.
            image: PIL image passed to the model alongside the prompt.
            platform: Target platform key (see ``self.platform_templates``).
            language: 'zh', 'en' or 'zh-en'.

        Returns:
            One cleaned caption dict per style whose output parsed into a usable
            caption (keys 'caption', 'hashtags', 'tone', 'platform', plus any extra
            keys the model emitted), or a single fallback caption if none did.
        """
        brands_in_image = analysis_results.get('brands') or []
        brand_names = [b[0] for b in brands_in_image[:3]]
        brand_mention_requirement = f" CRITICAL: Mention {', '.join(brand_names)} brand(s) naturally in the caption." if brand_names else ""

        styles = [
            {
                'name': 'professional',
                'temp': 0.6,
                'instruction': f'Professional style: Concise, elegant, sophisticated. Focus on quality and craftsmanship. Use refined language.{brand_mention_requirement}',
                'length_modifier': 0.8,  # shorter, more concise
            },
            {
                'name': 'creative',
                'temp': 0.7,
                'instruction': f'Creative style: Artistic, expressive, imaginative. Use vivid metaphors and sensory descriptions. Balance detail with flair.{brand_mention_requirement}',
                'length_modifier': 1.0,  # medium length
            },
            {
                'name': 'authentic',
                'temp': 0.8,
                'instruction': f'Authentic style: Personal, detailed, storytelling. Share rich observations and genuine feelings. Most descriptive and engaging.{brand_mention_requirement}',
                'length_modifier': 1.2,  # longer, more detailed
            },
        ]

        # The base prompt does not depend on the style, so build it once.
        base_prompt = self.construct_prompt(analysis_results, platform, language)

        variations = []
        for style in styles:
            style_prompt = f"""{base_prompt}
**STYLE REQUIREMENT FOR THIS CAPTION:**
{style['instruction']}
Adjust tone to be clearly '{style['name']}' - this should be noticeably different from other styles."""
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": style_prompt},
                ],
            }]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            # With device_map="auto" the model can sit on any device (GPU, CPU, MPS);
            # send inputs to wherever its parameters are instead of assuming "cuda".
            inputs = inputs.to(self.model.device)

            config = self.generation_config.copy()
            config['temperature'] = style['temp']
            with torch.no_grad():
                generated_ids = self.model.generate(**inputs, **config)
            # Keep only newly generated tokens (drop the echoed prompt).
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

            caption = self._postprocess_parsed(
                self._parse_json_output(output_text), style['name'], platform, language
            )
            if caption is not None:
                variations.append(caption)

        return variations if variations else [self._get_fallback_caption(platform, language)]

    def _postprocess_parsed(self, parsed: Dict, tone: str, platform: str, language: str) -> Dict:
        """Clean one parsed model response; return None when it has no usable caption.

        Forces 'tone' and 'platform', strips leaked hashtags from the caption,
        normalizes 'hashtags' to a list of bare tags, and converts Chinese output
        to Traditional characters.
        """
        if not isinstance(parsed, dict) or not isinstance(parsed.get('caption'), str):
            return None
        parsed['tone'] = tone
        parsed['platform'] = platform
        parsed['caption'] = self._remove_hashtags_from_caption(parsed['caption'])
        if not parsed['caption']:
            return None
        parsed['hashtags'] = self._normalize_hashtags(parsed.get('hashtags'))
        # Every non-English mode prompts for Chinese output (see construct_prompt).
        if language != 'en':
            parsed = self._convert_to_traditional(parsed)
        return parsed

    @staticmethod
    def _normalize_hashtags(hashtags) -> List[str]:
        """Return hashtags as a list of bare tags (no leading '#', no blanks).

        The prompt asks for a JSON list of tags without '#', the same shape as
        the fallback captions. A single string (e.g. "#a #b") is split on
        whitespace/commas; non-string items and any other type are dropped.
        """
        if isinstance(hashtags, str):
            hashtags = hashtags.replace(',', ' ').split()
        if not isinstance(hashtags, list):
            return []
        cleaned = []
        for tag in hashtags:
            if isinstance(tag, str):
                tag = tag.strip().lstrip('#').strip()
                if tag:
                    cleaned.append(tag)
        return cleaned

    def _remove_hashtags_from_caption(self, caption: str) -> str:
        """Remove any hashtags, pipes, and debug info that leaked into caption text."""
        import re

        # Drop a pipe and everything after it (debug info the model sometimes appends).
        # Example: "Text 🕰️🌉 | SoftDiffusedLight" -> "Text 🕰️🌉"
        if '|' in caption:
            caption = caption.split('|')[0].strip()

        # Remove hashtags; \w is Unicode-aware, the second pattern targets CJK tags explicitly.
        caption = re.sub(r'#\w+', '', caption)
        caption = re.sub(r'#[\u4e00-\u9fff]+', '', caption)

        # Drop a stray trailing ALL-CAPS token such as "BLACKBELT". Only plain-ASCII
        # tokens qualify: Chinese text has no spaces, so a whole sentence containing a
        # Latin brand ("這款CHANEL…") would otherwise count as one uppercase "word" and
        # be deleted. The only remaining word is never dropped.
        words = caption.split()
        if len(words) > 1:
            last_word = words[-1].strip('✨💎👗🌟💫🖤')
            if (last_word.isascii() and last_word.isupper() and len(last_word) > 3
                    and not any(char in last_word for char in '.,!?')):
                caption = ' '.join(words[:-1])

        # Drop a trailing run of 4+ emoji from U+1F300-U+1F9FF (the whole run is removed).
        caption = re.sub(r'[\U0001F300-\U0001F9FF]{4,}$', '', caption)

        # Collapse whitespace runs and trim.
        caption = re.sub(r'\s+', ' ', caption)
        caption = caption.strip()

        # Strip a trailing cluster of 2+ of these decorative emoji (e.g. "✨X 👗💎").
        # NOTE(review): this also removes endings like "✨🖤" that the prompt's own
        # example and emoji requirement ask for — confirm this is intended.
        if re.search(r'[✨💎👗🌟💫🖤]{2,}\s*$', caption):
            caption = re.sub(r'[✨💎👗🌟💫🖤\s]+$', '', caption).strip()

        return caption

    def _convert_to_traditional(self, caption: Dict) -> Dict:
        """Convert Simplified Chinese to Traditional Chinese in the caption and its hashtags.

        Mutates and returns ``caption``. Non-string values are left untouched.
        """
        if isinstance(caption.get('caption'), str):
            caption['caption'] = self.cc.convert(caption['caption'])
        hashtags = caption.get('hashtags')
        if isinstance(hashtags, list):
            caption['hashtags'] = [
                self.cc.convert(tag) if isinstance(tag, str) else tag for tag in hashtags
            ]
        return caption

    def _parse_json_output(self, text: str) -> Dict:
        """Extract the first JSON object embedded in model output.

        The model may wrap the object in prose or ```json fences, or follow it with
        more text. Each '{' is tried in order with ``raw_decode``, which ignores
        trailing content.

        Returns:
            The first decoded dict, or None if no JSON object can be decoded.
        """
        if not isinstance(text, str):
            return None
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict):
                    return obj
            start = text.find('{', start + 1)
        return None

    def _get_fallback_caption(self, platform: str, language: str) -> Dict:
        """Return a generic caption used when no model output could be parsed.

        English for language == 'en'; Traditional Chinese for every other value.
        """
        if language == 'en':
            return {
                'caption': 'Every moment tells a story worth sharing. The world around us is filled with beauty waiting to be discovered. Take a pause and appreciate the details that make life extraordinary. What caught your eye today? ✨',
                'hashtags': ['photography', 'daily', 'lifestyle', 'moment', 'capture'],
                'tone': 'casual',
                'platform': platform,
            }
        return {
            'caption': '每個瞬間都值得被記錄與分享。生活中充滿了等待被發現的美好細節。停下腳步,用心感受周遭的一切。今天什麼畫面觸動了你的心?✨',
            'hashtags': ['攝影', '日常', '生活', '瞬間', '分享'],
            'tone': 'casual',
            'platform': platform,
        }
# Import-time confirmation that the class above was defined.
_DEFINITION_BANNER = "✓ CaptionGenerationManager (with Auto* classes for flexible model support) defined"
print(_DEFINITION_BANNER)