|
|
|
|
|
import re |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
from prompt_library_manager import PromptLibraryManager |
|
|
|
|
|
class OutputProcessingManager:
    """Validate and format generated captions and build smart hashtags.

    Integrates a ``PromptLibraryManager`` to provide commercial-grade
    hashtag generation (landmark / brand / scene prompt libraries).
    """

    def __init__(self, prompt_library: Optional["PromptLibraryManager"] = None):
        """
        Args:
            prompt_library: PromptLibraryManager instance. Optional; a new
                one is created automatically when omitted.
        """
        # Lowercased banned words; empty by default, so nothing is filtered
        # until terms are added.
        self.profanity_filter = set()

        # Per-platform caption length limits (characters).
        self.max_lengths = {
            'instagram': 2200,
            'tiktok': 100,
            'xiaohongshu': 500,
        }

        # Create the prompt library only when the caller did not supply one.
        if prompt_library is None:
            self.prompt_library = PromptLibraryManager()
        else:
            self.prompt_library = prompt_library

        self.landmark_keywords = self._init_landmark_keywords()

        print("✓ OutputProcessingManager (with integrated PromptLibraryManager) initialized")

    def _init_landmark_keywords(self) -> Dict[str, List[str]]:
        """Return the landmark-detection keyword map.

        Used to infer a probable landmark from detected objects and
        scene labels.
        """
        return {
            'Big Ben': ['clock tower', 'tower', 'bridge', 'palace', 'gothic'],
            'Eiffel Tower': ['tower', 'iron', 'landmark', 'lattice'],
            'Statue of Liberty': ['statue', 'monument', 'harbor', 'torch'],
            'Golden Gate Bridge': ['bridge', 'suspension', 'orange', 'bay'],
            'Sydney Opera House': ['opera', 'building', 'harbor', 'shell'],
            'Taj Mahal': ['palace', 'dome', 'monument', 'marble'],
            'Colosseum': ['arena', 'amphitheater', 'ruins', 'ancient'],
            'Pyramids of Giza': ['pyramid', 'desert', 'ancient', 'monument'],
            'Burj Khalifa': ['skyscraper', 'tower', 'building', 'tall'],
            'Tokyo Tower': ['tower', 'lattice', 'red'],
            'Taipei 101': ['skyscraper', 'tower', 'building'],
        }

    def detect_landmark(self, detections: List[Dict], scene_info: Dict) -> Optional[str]:
        """Infer a probable landmark from detection results.

        Args:
            detections: YOLO detection results (dicts with a 'class_name' key).
            scene_info: Scene-analysis result; only scene_info['urban']['top']
                is consulted here.

        Returns:
            The inferred landmark name, or None when no landmark accumulates
            at least two keyword matches.
        """
        detected_objects = [d.get('class_name', '').lower() for d in detections]

        scene_keywords = []
        urban_scene = scene_info.get('urban', {}).get('top', '')
        if urban_scene:
            scene_keywords.append(urban_scene.lower())

        all_keywords = detected_objects + scene_keywords

        # Score each landmark by how many observed labels contain one of
        # its keywords (substring match, each label counted at most once).
        scores = {}
        for landmark, keywords in self.landmark_keywords.items():
            match_count = sum(1 for obj in all_keywords
                              if any(kw in obj for kw in keywords))
            if match_count > 0:
                scores[landmark] = match_count

        if scores:
            best_landmark = max(scores.items(), key=lambda x: x[1])
            # Require at least two independent matches to reduce false hits
            # from generic words such as 'tower' or 'bridge'.
            if best_landmark[1] >= 2:
                return best_landmark[0]

        return None

    def generate_smart_hashtags(self, detections: List[Dict], scene_info: Dict,
                                brands: List, platform: str, language: str) -> List[str]:
        """Generate smart hashtags combining landmark, brand and scene cues.

        Args:
            detections: Detected-object list.
            scene_info: Scene-analysis result.
            brands: Detected brands; each entry is a name or a (name, ...) tuple.
            platform: Target platform name.
            language: 'zh', 'en', or 'zh-en'.

        Returns:
            Ordered, deduplicated hashtag list (at most 10, no '#' prefixes).
        """
        hashtags = []

        # 1) Landmark tags (strongest signal) -- up to 5.
        detected_landmark = self.detect_landmark(detections, scene_info)
        if detected_landmark:
            landmark_tags = self.prompt_library.landmark_prompts.get_hashtags(
                detected_landmark, language
            )
            hashtags.extend(landmark_tags[:5])

        # 2) Brand tags -- at most 3 brands, 3 tags each.
        if brands:
            for brand in brands[:3]:
                brand_name = brand[0] if isinstance(brand, tuple) else brand
                brand_tags = self.prompt_library.brand_prompts.get_hashtags(
                    brand_name, language
                )
                hashtags.extend(brand_tags[:3])

        # 3) Scene-category tags -- up to 4.
        scene_category = self._detect_scene_category(scene_info, detections)
        if scene_category:
            scene_tags = self.prompt_library.scene_prompts.get_hashtags(
                scene_category, language
            )
            hashtags.extend(scene_tags[:4])

        # 4) Composition tags and 5) platform-specific tags.
        hashtags.extend(self._get_composition_hashtags(scene_info, language))
        hashtags.extend(self._get_platform_hashtags(platform, language))

        # Deduplicate while preserving first-seen order; drop empty strings.
        seen = set()
        unique_hashtags = []
        for tag in hashtags:
            if tag not in seen and tag:
                seen.add(tag)
                unique_hashtags.append(tag)

        return unique_hashtags[:10]

    def _detect_scene_category(self, scene_info: Dict, detections: List[Dict]) -> Optional[str]:
        """Classify the scene.

        Returns:
            One of 'food', 'nature', 'urban', 'indoor'; 'urban' is also the
            fallback when nothing matches.
        """
        object_classes = [d.get('class_name', '').lower() for d in detections]

        # Checks are ordered by priority: food beats nature beats urban/indoor.
        food_keywords = ['sandwich', 'pizza', 'cake', 'food', 'plate', 'bowl', 'cup', 'bottle']
        if any(kw in obj for kw in food_keywords for obj in object_classes):
            return 'food'

        nature_keywords = ['tree', 'mountain', 'water', 'sky', 'beach', 'ocean']
        if any(kw in obj for kw in nature_keywords for obj in object_classes):
            return 'nature'

        urban_scene = scene_info.get('urban', {}).get('top', '')
        if urban_scene and ('canyon' in urban_scene or 'street' in urban_scene or 'building' in urban_scene):
            return 'urban'

        indoor_keywords = ['chair', 'table', 'couch', 'bed', 'desk']
        if any(kw in obj for kw in indoor_keywords for obj in object_classes):
            return 'indoor'

        # Default: treat unclassified scenes as urban.
        return 'urban'

    def _get_composition_hashtags(self, scene_info: Dict, language: str) -> List[str]:
        """Return hashtags derived from the composition type (plus a generic
        photography tag in the requested language)."""
        hashtags = []

        composition = scene_info.get('urban', {}).get('top', '')

        if 'canyon' in composition or 'skyscraper' in composition:
            if language == 'zh':
                hashtags.extend(['城市峽谷', '城市風景'])
            elif language == 'en':
                hashtags.extend(['UrbanCanyon', 'Cityscape'])
            else:
                # Mixed-language output: one tag of each.
                hashtags.extend(['城市峽谷', 'UrbanCanyon'])

        if language == 'zh':
            hashtags.append('攝影日常')
        elif language == 'en':
            hashtags.append('Photography')
        else:
            hashtags.extend(['攝影日常', 'Photography'])

        return hashtags

    def _get_platform_hashtags(self, platform: str, language: str) -> List[str]:
        """Return platform-specific hashtags in the requested language."""
        hashtags = []

        if platform == 'instagram':
            if language == 'zh':
                hashtags.append('IG日常')
            elif language == 'en':
                hashtags.append('InstaDaily')
            else:
                hashtags.extend(['IG日常', 'InstaDaily'])

        elif platform == 'tiktok':
            if language == 'zh':
                hashtags.append('抖音')
            elif language == 'en':
                hashtags.append('TikTok')
            else:
                hashtags.extend(['抖音', 'TikTok'])

        elif platform == 'xiaohongshu':
            # Xiaohongshu tags are Chinese regardless of language setting.
            hashtags.extend(['小紅書', '分享日常'])

        return hashtags

    def validate_output(self, output: Dict, platform: str,
                        detections: List[Dict] = None, scene_info: Dict = None,
                        brands: List = None, language: str = 'en') -> Tuple[bool, str]:
        """Validate (and normalize in place) a generated caption dict.

        May mutate ``output``: the caption is truncated to the platform
        limit, hashtags are cleaned, missing hashtags are auto-supplemented,
        and inline '#tags' are stripped from the caption body.

        Args:
            output: Generated caption dict; must contain 'caption',
                'hashtags', 'tone', and 'platform'.
            platform: Platform name (selects the length limit).
            detections: Detection results (for hashtag supplementation).
            scene_info: Scene info (for hashtag supplementation).
            brands: Brand list (for hashtag supplementation).
            language: Language code.

        Returns:
            (passed, message) tuple.
        """
        required_fields = ['caption', 'hashtags', 'tone', 'platform']
        if not all(field in output for field in required_fields):
            return False, "Missing required fields"

        # Truncate over-long captions instead of failing the validation.
        max_length = self.max_lengths.get(platform, 2200)
        if len(output['caption']) > max_length:
            output['caption'] = output['caption'][:max_length-3] + '...'

        if self._contains_profanity(output['caption']):
            return False, "Contains inappropriate content"

        output['hashtags'] = self._validate_hashtags(output['hashtags'])

        # Auto-supplement when fewer than 5 hashtags survived cleaning and
        # enough context was supplied to generate more.
        min_hashtags = 5
        if len(output['hashtags']) < min_hashtags:
            if detections is not None and scene_info is not None:
                # Remember the deficient count so the log message reports the
                # before/after numbers correctly.
                before_count = len(output['hashtags'])
                additional_tags = self.generate_smart_hashtags(
                    detections, scene_info, brands or [], platform, language
                )
                for tag in additional_tags:
                    if tag not in output['hashtags'] and len(output['hashtags']) < 10:
                        output['hashtags'].append(tag)

                print(f" [AUTO-補充] 標籤數量不足 ({before_count} < {min_hashtags}),已自動補充至 {len(output['hashtags'])} 個")

        # Hashtags live in the 'hashtags' field only: strip any inline
        # '#word' tokens from the caption body.
        if '#' in output['caption']:
            output['caption'] = re.sub(r'#\w+', '', output['caption']).strip()

        return True, "Validation passed"

    def _contains_profanity(self, text: str) -> bool:
        """Return True when any filtered word occurs in the text
        (case-insensitive substring match)."""
        text_lower = text.lower()
        return any(word in text_lower for word in self.profanity_filter)

    def _validate_hashtags(self, hashtags: List[str]) -> List[str]:
        """Clean and deduplicate raw hashtags.

        Strips leading '#' characters, removes everything that is not a word
        character or a CJK unified ideograph, and keeps first-seen order.

        Args:
            hashtags: Raw hashtag list.

        Returns:
            Cleaned hashtag list (at most 10 entries, no empty strings).
        """
        cleaned = []
        for tag in hashtags:
            tag = tag.lstrip('#')
            # Keep word characters and CJK unified ideographs only.
            tag = re.sub(r'[^\w\u4e00-\u9fff]', '', tag)
            if tag and tag not in cleaned:
                cleaned.append(tag)
        return cleaned[:10]

    def format_for_platform(self, caption: Dict, platform: str) -> str:
        """Render a caption dict as platform-ready text.

        Args:
            caption: Caption dict with 'caption' and 'hashtags' keys.
            platform: Platform name.

        Returns:
            Formatted string: caption body followed by '#'-prefixed tags.
        """
        formatted = f"{caption['caption']}\n\n"

        if platform == 'xiaohongshu':
            # Xiaohongshu convention: tags follow the body directly.
            formatted += ' '.join([f"#{tag}" for tag in caption['hashtags']])
        else:
            # Other platforms: an extra blank line before the tag row.
            formatted += '\n' + ' '.join([f"#{tag}" for tag in caption['hashtags']])

        return formatted
|
|
|
|
|
# Module-load confirmation banner; mirrors the instance-level banner printed
# by OutputProcessingManager.__init__.
print("✓ OutputProcessingManager (V3 with PromptLibraryManager integration) defined")
|
|
|