import torch import math from PIL import Image from typing import Dict, List, Tuple from rapidfuzz import fuzz from prompt_library_manager import PromptLibraryManager from brand_detection_optimizer import BrandDetectionOptimizer class BrandRecognitionManager: """Multi-modal brand recognition with detailed prompts (Visual + Text)""" def __init__(self, clip_manager, ocr_manager, prompt_library=None): self.clip_manager = clip_manager self.ocr_manager = ocr_manager self.prompt_library = prompt_library self.flat_brands = prompt_library.get_all_brands() # Initialize optimizer for smart brand detection self.optimizer = BrandDetectionOptimizer(clip_manager, ocr_manager, prompt_library) print(f"✓ Brand Recognition Manager loaded with {len(self.flat_brands)} brands (with optimizer)") def recognize_brand(self, image_region: Image.Image, full_image: Image.Image, region_bbox: List[int] = None) -> List[Tuple[str, float, List[int]]]: """Recognize brands using detailed context-aware prompts Args: image_region: Cropped region containing potential brand full_image: Full image for OCR region_bbox: Bounding box [x1, y1, x2, y2] for visualization Returns: List of (brand_name, confidence, bbox) tuples """ # Step 1: Classify region context region_context = self._classify_region_context(image_region) print(f" [DEBUG] Region context classified as: {region_context}") # Step 2: Use context-specific OpenCLIP prompts brand_scores = {} for brand_name, brand_info in self.flat_brands.items(): # Get best matching context for this brand best_context = self._match_region_to_brand_context(region_context, brand_info['region_contexts']) if best_context and best_context in brand_info['openclip_prompts']: # Use context-specific prompts prompts = brand_info['openclip_prompts'][best_context] visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts) # Average scores from all prompts avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 else: # Fallback to strong cues prompts = brand_info['strong_cues'][:5] # Top 5 strong cues visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts) avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 brand_scores[brand_name] = avg_score # Step 2.5: Multi-scale visual matching for better robustness brand_scores = self._multi_scale_visual_matching(image_region, brand_scores) # Step 3: OCR text matching with brand-optimized preprocessing ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True) text_matches = self._fuzzy_text_matching(ocr_results) print(f" [DEBUG] OCR found {len(ocr_results)} text regions") if text_matches: print(f" [DEBUG] OCR brand matches: {text_matches}") # Step 4: Adaptive weighted fusion (dynamic weights per brand) final_scores = {} for brand_name in self.flat_brands.keys(): visual_score = brand_scores.get(brand_name, 0.0) text_score, ocr_conf = text_matches.get(brand_name, (0.0, 0.0)) # Calculate adaptive weights based on brand characteristics visual_weight, text_weight, ocr_weight = self._calculate_adaptive_weights( brand_name, visual_score, text_score, ocr_conf ) # Weighted fusion with adaptive weights final_score = ( visual_weight * self._scale_visual(visual_score) + text_weight * text_score + ocr_weight * ocr_conf ) final_scores[brand_name] = final_score sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:5] print(f" [DEBUG] Top 5 brand scores:") for brand, score in sorted_scores: print(f" {brand}: {score:.4f} (visual={brand_scores.get(brand, 0):.4f}, text={text_matches.get(brand, (0, 0))[0]:.4f})") # Return confident matches with bounding boxes confident_brands = [] for brand_name, score in final_scores.items(): if score > 0.10: confident_brands.append((brand_name, score, region_bbox)) print(f" [DEBUG] ✓ Brand detected: {brand_name} (confidence: {score:.4f})") confident_brands.sort(key=lambda x: x[1], reverse=True) if not confident_brands: print(f" [DEBUG] ✗ No brands passed threshold 0.10") return confident_brands def _classify_region_context(self, image_region: Image.Image) -> str: """Classify what type of region this is (bag_panel, shoe_side, etc.)""" context_labels = [ 'bag panel with pattern', 'luggage surface with branding', 'luxury trunk with monogram pattern', 'vintage travel trunk with hardware', 'shoe side view', 'device back cover', 'apparel chest area', 'belt buckle', 'storefront sign', 'product tag or label', 'wallet surface', 'perfume bottle', 'watch dial or face', 'car front grille', 'laptop lid' ] scores = self.clip_manager.classify_zero_shot(image_region, context_labels) # Map to simplified contexts context_mapping = { 'bag panel with pattern': 'bag_panel', 'luggage surface with branding': 'luggage_surface', 'luxury trunk with monogram pattern': 'trunk_body', 'vintage travel trunk with hardware': 'trunk_body', 'shoe side view': 'shoe_side', 'device back cover': 'device_back', 'apparel chest area': 'apparel_chest', 'belt buckle': 'belt_buckle', 'storefront sign': 'storefront', 'product tag or label': 'product_tag', 'wallet surface': 'wallet', 'perfume bottle': 'perfume_bottle', 'watch dial or face': 'watch_dial', 'car front grille': 'car_front', 'laptop lid': 'laptop_lid' } top_context = max(scores.items(), key=lambda x: x[1])[0] return context_mapping.get(top_context, 'unknown') def _match_region_to_brand_context(self, region_context: str, brand_contexts: List[str]) -> str: """Match detected region context to brand's available contexts""" if region_context in brand_contexts: return region_context # Fuzzy matching for brand_context in brand_contexts: if region_context.split('_')[0] in brand_context: return brand_context return None def _fuzzy_text_matching(self, ocr_results: List[Dict]) -> Dict[str, Tuple[float, float]]: """Fuzzy text matching using brand aliases (optimized for logo text)""" matches = {} for ocr_item in ocr_results: text = ocr_item['text'] conf = ocr_item['confidence'] for brand_name, brand_info in self.flat_brands.items(): # Check all aliases all_names = [brand_name] + brand_info.get('aliases', []) for alias in all_names: ratio = fuzz.ratio(text, alias) / 100.0 if ratio > 0.70: # Lowered threshold for better recall if brand_name not in matches or ratio > matches[brand_name][0]: matches[brand_name] = (ratio, conf) return matches def _scale_visual(self, score: float) -> float: """Scale visual score using sigmoid""" return 1 / (1 + math.exp(-10 * (score - 0.5))) def _calculate_adaptive_weights(self, brand_name: str, visual_score: float, text_score: float, ocr_conf: float) -> tuple: """ Calculate adaptive weights based on brand characteristics and signal strengths Args: brand_name: Name of the brand visual_score: Visual similarity score text_score: Text matching score ocr_conf: OCR confidence Returns: Tuple of (visual_weight, text_weight, ocr_weight) """ brand_info = self.prompt_library.get_brand_prompts(brand_name) if not brand_info: # Default balanced weights return 0.50, 0.30, 0.20 # Base weights based on brand characteristics if brand_info.get('visual_distinctive', False): # Visually distinctive brands (LV, Burberry) visual_weight = 0.65 text_weight = 0.20 ocr_weight = 0.15 elif brand_info.get('text_prominent', False): # Text-prominent brands (Nike, Adidas) visual_weight = 0.30 text_weight = 0.30 ocr_weight = 0.40 else: # Balanced for general brands visual_weight = 0.50 text_weight = 0.30 ocr_weight = 0.20 # Dynamic adjustment based on signal strength # If visual signal is very strong, boost its weight if visual_score > 0.7: boost = 0.10 visual_weight += boost text_weight -= boost * 0.5 ocr_weight -= boost * 0.5 # If OCR has very high confidence, boost its weight if ocr_conf > 0.85: boost = 0.10 ocr_weight += boost visual_weight -= boost * 0.6 text_weight -= boost * 0.4 # If text match is very strong, boost its weight if text_score > 0.80: boost = 0.08 text_weight += boost visual_weight -= boost * 0.5 ocr_weight -= boost * 0.5 # Normalize weights to sum to 1 total = visual_weight + text_weight + ocr_weight return visual_weight / total, text_weight / total, ocr_weight / total def _multi_scale_visual_matching(self, image_region: Image.Image, initial_scores: Dict[str, float]) -> Dict[str, float]: """ Apply multi-scale matching to improve robustness Args: image_region: Image region to analyze initial_scores: Initial brand scores from single-scale matching Returns: Updated brand scores with multi-scale matching """ scales = [0.8, 1.0, 1.2] # Three scales multi_scale_scores = {brand: [] for brand in initial_scores.keys()} for scale in scales: # Resize image new_width = int(image_region.width * scale) new_height = int(image_region.height * scale) # Ensure minimum size if new_width < 50 or new_height < 50: continue try: scaled_img = image_region.resize((new_width, new_height), Image.Resampling.LANCZOS) # Re-run classification on each brand's prompts for brand_name, brand_info in self.flat_brands.items(): # Get context-specific prompts best_context = self._match_region_to_brand_context( 'bag_panel', # Default context, ideally should be passed as parameter brand_info.get('region_contexts', []) ) if best_context and best_context in brand_info.get('openclip_prompts', {}): prompts = brand_info['openclip_prompts'][best_context] visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts) avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 else: prompts = brand_info.get('strong_cues', [])[:3] visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts) avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 multi_scale_scores[brand_name].append(avg_score) except Exception as e: # Skip this scale if error occurs continue # Aggregate multi-scale scores (use max score across scales) final_scores = {} for brand_name, scores in multi_scale_scores.items(): if scores: final_scores[brand_name] = max(scores) else: final_scores[brand_name] = initial_scores.get(brand_name, 0.0) return final_scores def scan_full_image_for_brands(self, full_image: Image.Image, exclude_bboxes: List[List[int]] = None, saliency_regions: List[Dict] = None) -> List[Tuple[str, float, List[int]]]: """ 智能全圖品牌掃描 - 性能優化版本 使用預篩選和智能區域選擇大幅減少檢測時間 Args: full_image: PIL Image (full image) exclude_bboxes: List of bboxes to exclude (already detected) saliency_regions: Saliency detection results for smart region selection Returns: List of (brand_name, confidence, bbox) tuples """ if exclude_bboxes is None: exclude_bboxes = [] detected_brands = {} # brand_name -> (confidence, bbox) img_width, img_height = full_image.size # OPTIMIZATION 1: 快速品牌預篩選 likely_brands = self.optimizer.quick_brand_prescreening(full_image) print(f" Quick prescreening found {len(likely_brands)} potential brands") # OPTIMIZATION 2: 智能區域選擇(只掃描有意義的區域) regions_to_scan = self.optimizer.smart_region_selection(full_image, saliency_regions or []) print(f" Scanning {len(regions_to_scan)} intelligent regions") # 掃描選定的區域 for region_bbox in regions_to_scan: x1, y1, x2, y2 = region_bbox # 跳過已檢測區域 if self._bbox_overlap(list(region_bbox), exclude_bboxes): continue # 提取區域 region = full_image.crop(region_bbox) # 只檢測預篩選的品牌(而非所有20+品牌) for brand_name in likely_brands: brand_info = self.flat_brands.get(brand_name) if not brand_info: continue # only use strong_cues strong_cues = brand_info.get('strong_cues', [])[:5] # Top 5 if not strong_cues: continue visual_scores = self.clip_manager.classify_zero_shot(region, strong_cues) avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0 # OCR 增強 ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True) boosted_score = self.optimizer.compute_brand_confidence_boost( brand_name, ocr_results, avg_score ) # 極度寬鬆的閾值以最大化檢測率 if boosted_score > 0.08: # 降低到 0.08 # 更新最佳結果 if brand_name not in detected_brands or boosted_score > detected_brands[brand_name][0]: detected_brands[brand_name] = (boosted_score, list(region_bbox)) # 轉換為列表格式 final_brands = [ (brand_name, confidence, bbox) for brand_name, (confidence, bbox) in detected_brands.items() ] # 按信心度排序 final_brands.sort(key=lambda x: x[1], reverse=True) return final_brands[:5] # 返回前5個 def _bbox_overlap(self, bbox1: List[int], bbox_list: List[List[int]]) -> bool: """Check if bbox1 overlaps significantly with any bbox in bbox_list""" if not bbox_list: return False x1_1, y1_1, x2_1, y2_1 = bbox1 for bbox2 in bbox_list: if bbox2 is None: continue x1_2, y1_2, x2_2, y2_2 = bbox2 # Calculate intersection x_left = max(x1_1, x1_2) y_top = max(y1_1, y1_2) x_right = min(x2_1, x2_2) y_bottom = min(y2_1, y2_2) if x_right < x_left or y_bottom < y_top: continue intersection_area = (x_right - x_left) * (y_bottom - y_top) bbox1_area = (x2_1 - x1_1) * (y2_1 - y1_1) # 如果重疊超過 30%,視為重疊 if intersection_area / bbox1_area > 0.3: return True return False print("✓ BrandRecognitionManager (with full-image scan for commercial use) defined")