# Pixcribe / output_processing_manager.py
# Author: DawnC — commit 6a3bd1f ("Upload 22 files")
import re
from typing import Dict, List, Tuple, Optional
from prompt_library_manager import PromptLibraryManager
class OutputProcessingManager:
    """
    Output validation, formatting and smart hashtag generation.

    Uses a PromptLibraryManager (its ``landmark_prompts``, ``brand_prompts``
    and ``scene_prompts`` tables) to generate hashtags that supplement
    model-produced captions.
    """

    # Upper bound on the number of hashtags kept on any output.
    MAX_HASHTAGS = 10
    # validate_output() tops hashtags up from detection data below this count.
    MIN_HASHTAGS = 5

    # An inline hashtag plus the whitespace in front of it; removing both
    # avoids leaving double spaces / dangling blank lines in the caption.
    _INLINE_HASHTAG_RE = re.compile(r'\s*#\w+')
    # Anything that is not a word character or a CJK ideograph is dropped
    # from a hashtag.
    _HASHTAG_INVALID_CHARS_RE = re.compile(r'[^\w\u4e00-\u9fff]')

    def __init__(self, prompt_library: Optional["PromptLibraryManager"] = None):
        """
        Args:
            prompt_library: PromptLibraryManager instance to use; a new one is
                created when omitted.
        """
        # Lower-case words whose presence in a caption fails validation.
        self.profanity_filter = set()
        # Maximum caption length per platform (2200 is used for unknown ones).
        self.max_lengths = {
            'instagram': 2200,
            'tiktok': 100,
            'xiaohongshu': 500,
        }
        if prompt_library is None:
            prompt_library = PromptLibraryManager()
        self.prompt_library = prompt_library
        # Keyword table used by detect_landmark() for simple landmark guessing.
        self.landmark_keywords = self._init_landmark_keywords()
        print("✓ OutputProcessingManager (with integrated PromptLibraryManager) initialized")

    def _init_landmark_keywords(self) -> Dict[str, List[str]]:
        """
        Build the landmark -> clue-keywords table.

        A landmark is guessed when several detected object / scene labels
        contain one of its keywords (substring match).
        """
        return {
            'Big Ben': ['clock tower', 'tower', 'bridge', 'palace', 'gothic'],
            'Eiffel Tower': ['tower', 'iron', 'landmark', 'lattice'],
            'Statue of Liberty': ['statue', 'monument', 'harbor', 'torch'],
            'Golden Gate Bridge': ['bridge', 'suspension', 'orange', 'bay'],
            'Sydney Opera House': ['opera', 'building', 'harbor', 'shell'],
            'Taj Mahal': ['palace', 'dome', 'monument', 'marble'],
            'Colosseum': ['arena', 'amphitheater', 'ruins', 'ancient'],
            'Pyramids of Giza': ['pyramid', 'desert', 'ancient', 'monument'],
            'Burj Khalifa': ['skyscraper', 'tower', 'building', 'tall'],
            'Tokyo Tower': ['tower', 'lattice', 'red'],
            'Taipei 101': ['skyscraper', 'tower', 'building'],
            # More landmarks can be added here.
        }

    @staticmethod
    def _urban_label(scene_info: Optional[Dict]) -> str:
        """Return ``scene_info['urban']['top']`` lower-cased, or '' if absent."""
        urban = (scene_info or {}).get('urban') or {}
        if not isinstance(urban, dict):
            return ''
        return str(urban.get('top') or '').lower()

    @staticmethod
    def _class_names(detections: Optional[List[Dict]]) -> List[str]:
        """Return the lower-cased ``class_name`` of each detection ('' if missing)."""
        return [str(d.get('class_name') or '').lower() for d in detections or []]

    @staticmethod
    def _localized(zh_tag: str, en_tag: str, language: str) -> List[str]:
        """Pick the zh tag, the en tag, or both (any other language value)."""
        if language == 'zh':
            return [zh_tag]
        if language == 'en':
            return [en_tag]
        return [zh_tag, en_tag]

    def detect_landmark(self, detections: List[Dict], scene_info: Dict,
                        min_matches: int = 2) -> Optional[str]:
        """
        Guess a landmark from detected object labels and the urban scene label.

        Args:
            detections: detection dicts; only their 'class_name' is read.
            scene_info: scene analysis; only ``['urban']['top']`` is read.
            min_matches: number of distinct clues that must match a landmark.

        Returns:
            The best-scoring landmark name, or None when no landmark reaches
            ``min_matches`` distinct matching clues. Ties go to the landmark
            listed first in ``self.landmark_keywords``.
        """
        # Distinct non-empty clues: several detections of the same class are
        # one piece of evidence, not several.
        clues = list(dict.fromkeys(
            clue
            for clue in self._class_names(detections) + [self._urban_label(scene_info)]
            if clue
        ))

        scores = {}
        for landmark, keywords in self.landmark_keywords.items():
            match_count = sum(1 for clue in clues if any(kw in clue for kw in keywords))
            if match_count > 0:
                scores[landmark] = match_count

        if not scores:
            return None
        best_landmark, best_score = max(scores.items(), key=lambda item: item[1])
        return best_landmark if best_score >= min_matches else None

    def generate_smart_hashtags(self, detections: List[Dict], scene_info: Dict,
                                brands: List, platform: str, language: str) -> List[str]:
        """
        Generate hashtags from landmark, brand, scene, composition and platform.

        Args:
            detections: detection dicts (their 'class_name' is read).
            scene_info: scene analysis result.
            brands: brand names, or (name, ...) tuples/lists; at most 3 are used.
            platform: platform name ('instagram', 'tiktok', 'xiaohongshu', ...).
            language: 'zh', 'en', or anything else for bilingual output.

        Returns:
            Up to MAX_HASHTAGS unique, non-empty tags in priority order:
            landmark > brand > scene > composition > platform.
        """
        library = self.prompt_library
        hashtags: List[str] = []

        # 1. Landmark (highest priority), at most 5 tags.
        landmark = self.detect_landmark(detections, scene_info)
        if landmark:
            hashtags.extend((library.landmark_prompts.get_hashtags(landmark, language) or [])[:5])

        # 2. Brands: at most 3 brands, 3 tags each.
        for brand in (brands or [])[:3]:
            brand_name = brand[0] if isinstance(brand, (tuple, list)) else brand
            hashtags.extend((library.brand_prompts.get_hashtags(brand_name, language) or [])[:3])

        # 3. Scene category, at most 4 tags.
        scene_category = self._detect_scene_category(scene_info, detections)
        if scene_category:
            hashtags.extend((library.scene_prompts.get_hashtags(scene_category, language) or [])[:4])

        # 4./5. Composition and platform tags.
        hashtags.extend(self._get_composition_hashtags(scene_info, language))
        hashtags.extend(self._get_platform_hashtags(platform, language))

        # Order-preserving de-duplication that also drops empty tags.
        unique_hashtags = list(dict.fromkeys(tag for tag in hashtags if tag))
        return unique_hashtags[:self.MAX_HASHTAGS]

    def _detect_scene_category(self, scene_info: Dict, detections: List[Dict]) -> Optional[str]:
        """
        Classify the scene from detected objects and the urban scene label.

        Checks, in order: food objects -> 'food', nature objects -> 'nature',
        urban label containing canyon/street/building -> 'urban', indoor
        furniture -> 'indoor'; otherwise falls back to 'urban' (so in
        practice this never returns None).
        """
        object_classes = self._class_names(detections)

        def seen(keywords) -> bool:
            # Substring match so e.g. 'table' matches 'dining table'.
            return any(kw in obj for kw in keywords for obj in object_classes)

        if seen(('sandwich', 'pizza', 'cake', 'food', 'plate', 'bowl', 'cup', 'bottle')):
            return 'food'
        if seen(('tree', 'mountain', 'water', 'sky', 'beach', 'ocean')):
            return 'nature'

        urban_scene = self._urban_label(scene_info)
        if any(word in urban_scene for word in ('canyon', 'street', 'building')):
            return 'urban'

        if seen(('chair', 'table', 'couch', 'bed', 'desk')):
            return 'indoor'
        return 'urban'  # default

    def _get_composition_hashtags(self, scene_info: Dict, language: str) -> List[str]:
        """Return composition-specific tags plus a generic photography tag."""
        hashtags: List[str] = []
        composition = self._urban_label(scene_info)

        # Urban canyon / skyscraper compositions.
        if 'canyon' in composition or 'skyscraper' in composition:
            if language == 'zh':
                hashtags.extend(['城市峽谷', '城市風景'])
            elif language == 'en':
                hashtags.extend(['UrbanCanyon', 'Cityscape'])
            else:  # bilingual
                hashtags.extend(['城市峽谷', 'UrbanCanyon'])

        # Generic photography tag.
        hashtags.extend(self._localized('攝影日常', 'Photography', language))
        return hashtags

    def _get_platform_hashtags(self, platform: str, language: str) -> List[str]:
        """Return platform-specific tags (xiaohongshu tags ignore ``language``)."""
        if platform == 'instagram':
            return self._localized('IG日常', 'InstaDaily', language)
        if platform == 'tiktok':
            return self._localized('抖音', 'TikTok', language)
        if platform == 'xiaohongshu':
            return ['小紅書', '分享日常']
        return []

    def validate_output(self, output: Dict, platform: str,
                        detections: List[Dict] = None, scene_info: Dict = None,
                        brands: List = None, language: str = 'en') -> Tuple[bool, str]:
        """
        Validate and normalize a generated caption dict in place.

        Steps: required-field check, profanity check, removal of inline
        hashtags from the caption, length truncation, hashtag cleaning, and
        (when ``detections`` and ``scene_info`` are given) topping hashtags
        up to at least MIN_HASHTAGS via generate_smart_hashtags().

        Args:
            output: dict with 'caption', 'hashtags', 'tone', 'platform'.
            platform: platform name, used for the caption length limit.
            detections: detection results used to supplement hashtags.
            scene_info: scene analysis used to supplement hashtags.
            brands: detected brands used to supplement hashtags.
            language: language of supplemented hashtags.

        Returns:
            (passed, message).
        """
        # 1. Structure.
        required_fields = ('caption', 'hashtags', 'tone', 'platform')
        if not all(field in output for field in required_fields):
            return False, "Missing required fields"

        caption = output['caption']

        # 2. Content filter on the full caption, hashtags included, before
        #    anything is stripped or truncated away.
        if self._contains_profanity(caption):
            return False, "Contains inappropriate content"

        # 3. Hashtags belong in the 'hashtags' field, not in the caption.
        #    Done before truncation so the limit applies to the final text.
        if '#' in caption:
            caption = self._INLINE_HASHTAG_RE.sub('', caption).strip()

        # 4. Length.
        max_length = self.max_lengths.get(platform, 2200)
        if len(caption) > max_length:
            ellipsis = '...'
            if max_length > len(ellipsis):
                caption = caption[:max_length - len(ellipsis)] + ellipsis
            else:
                caption = caption[:max_length]
        output['caption'] = caption

        # 5. Hashtag cleaning.
        hashtags = self._validate_hashtags(output['hashtags'])

        # 6. Top up too-short hashtag lists from detection data.
        original_count = len(hashtags)
        if original_count < self.MIN_HASHTAGS and detections is not None and scene_info is not None:
            # Clean generated tags too, so they dedupe against existing ones
            # and never carry a leading '#'.
            additional_tags = self._validate_hashtags(self.generate_smart_hashtags(
                detections, scene_info, brands or [], platform, language
            ))
            for tag in additional_tags:
                if len(hashtags) >= self.MAX_HASHTAGS:
                    break
                if tag not in hashtags:
                    hashtags.append(tag)
            print(f" [AUTO-補充] 標籤數量不足 ({original_count} < {self.MIN_HASHTAGS}),已自動補充至 {len(hashtags)} 個")
        output['hashtags'] = hashtags

        return True, "Validation passed"

    def _contains_profanity(self, text: str) -> bool:
        """Return True if any filtered word occurs in ``text`` (case-insensitive)."""
        text_lower = (text or '').lower()
        return any(word.lower() in text_lower for word in self.profanity_filter)

    def _validate_hashtags(self, hashtags: List[str]) -> List[str]:
        """
        Clean a hashtag list.

        Drops non-string entries, removes '#' and every character that is not
        a word character or CJK ideograph, discards empty results and
        duplicates (first occurrence wins), and keeps at most MAX_HASHTAGS.
        """
        cleaned: List[str] = []
        seen = set()
        for raw in hashtags or []:
            if not isinstance(raw, str):
                continue
            tag = self._HASHTAG_INVALID_CHARS_RE.sub('', raw)
            if tag and tag not in seen:
                seen.add(tag)
                cleaned.append(tag)
                if len(cleaned) >= self.MAX_HASHTAGS:
                    break
        return cleaned

    def format_for_platform(self, caption: Dict, platform: str) -> str:
        """
        Render a caption dict as postable text.

        Hashtags are appended after one blank line on xiaohongshu and after
        two blank lines elsewhere. With no hashtags, the caption text alone is
        returned (no trailing blank lines).

        Args:
            caption: dict with 'caption' and optionally 'hashtags'.
            platform: platform name.
        """
        text = caption['caption']
        tag_line = ' '.join(f"#{tag}" for tag in caption.get('hashtags') or [])
        if not tag_line:
            return text
        separator = '\n\n' if platform == 'xiaohongshu' else '\n\n\n'
        return f"{text}{separator}{tag_line}"
# Import-time side effect: prints a notice when this module is loaded.
print("✓ OutputProcessingManager (V3 with PromptLibraryManager integration) defined")