import math
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from rapidfuzz import fuzz

from prompt_library_manager import PromptLibraryManager
from brand_detection_optimizer import BrandDetectionOptimizer


class BrandRecognitionManager:
    """Multi-modal brand recognition with detailed prompts (visual + text)."""

    def __init__(self, clip_manager, ocr_manager, prompt_library=None):
        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        # Fall back to a default library so passing prompt_library=None does not crash below.
        self.prompt_library = prompt_library or PromptLibraryManager()
        self.flat_brands = self.prompt_library.get_all_brands()

        self.optimizer = BrandDetectionOptimizer(clip_manager, ocr_manager, self.prompt_library)

        print(f"✓ Brand Recognition Manager loaded with {len(self.flat_brands)} brands (with optimizer)")

    def recognize_brand(self, image_region: Image.Image, full_image: Image.Image,
                        region_bbox: Optional[List[int]] = None) -> List[Tuple[str, float, List[int]]]:
        """Recognize brands using detailed context-aware prompts.

        Args:
            image_region: Cropped region containing a potential brand
            full_image: Full image for OCR
            region_bbox: Bounding box [x1, y1, x2, y2] for visualization

        Returns:
            List of (brand_name, confidence, bbox) tuples
        """
        # Step 1: classify the region type so brand-specific prompts can be selected.
        region_context = self._classify_region_context(image_region)
        print(f"   [DEBUG] Region context classified as: {region_context}")

        # Step 2: visual matching per brand, preferring context-specific prompts.
        brand_scores = {}
        for brand_name, brand_info in self.flat_brands.items():
            best_context = self._match_region_to_brand_context(
                region_context, brand_info.get('region_contexts', [])
            )

            if best_context and best_context in brand_info.get('openclip_prompts', {}):
                prompts = brand_info['openclip_prompts'][best_context]
            else:
                # No context-specific prompts: fall back to the brand's strongest generic cues.
                prompts = brand_info.get('strong_cues', [])[:5]

            visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts)
            avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
            brand_scores[brand_name] = avg_score

        # Step 3: refine visual scores with multi-scale matching at the detected context.
        brand_scores = self._multi_scale_visual_matching(image_region, brand_scores, region_context)

        # Step 4: OCR the full image and fuzzy-match the text against brand aliases.
        ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)
        text_matches = self._fuzzy_text_matching(ocr_results)

        print(f"   [DEBUG] OCR found {len(ocr_results)} text regions")
        if text_matches:
            print(f"   [DEBUG] OCR brand matches: {text_matches}")

        # Step 5: fuse visual, text, and OCR signals with adaptive per-brand weights.
        final_scores = {}
        for brand_name in self.flat_brands:
            visual_score = brand_scores.get(brand_name, 0.0)
            text_score, ocr_conf = text_matches.get(brand_name, (0.0, 0.0))

            visual_weight, text_weight, ocr_weight = self._calculate_adaptive_weights(
                brand_name, visual_score, text_score, ocr_conf
            )

            final_score = (
                visual_weight * self._scale_visual(visual_score) +
                text_weight * text_score +
                ocr_weight * ocr_conf
            )
            final_scores[brand_name] = final_score

        sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:5]
        print("   [DEBUG] Top 5 brand scores:")
        for brand, score in sorted_scores:
            print(f"      {brand}: {score:.4f} (visual={brand_scores.get(brand, 0):.4f}, "
                  f"text={text_matches.get(brand, (0, 0))[0]:.4f})")

        # Step 6: keep only brands above the confidence threshold.
        confident_brands = []
        for brand_name, score in final_scores.items():
            if score > 0.10:
                confident_brands.append((brand_name, score, region_bbox))
                print(f"   [DEBUG] ✓ Brand detected: {brand_name} (confidence: {score:.4f})")

        confident_brands.sort(key=lambda x: x[1], reverse=True)

        if not confident_brands:
            print("   [DEBUG] ✗ No brands passed threshold 0.10")

        return confident_brands

    def _classify_region_context(self, image_region: Image.Image) -> str:
        """Classify what type of region this is (bag_panel, shoe_side, etc.)."""
        context_labels = [
            'bag panel with pattern',
            'luggage surface with branding',
            'luxury trunk with monogram pattern',
            'vintage travel trunk with hardware',
            'shoe side view',
            'device back cover',
            'apparel chest area',
            'belt buckle',
            'storefront sign',
            'product tag or label',
            'wallet surface',
            'perfume bottle',
            'watch dial or face',
            'car front grille',
            'laptop lid'
        ]

        scores = self.clip_manager.classify_zero_shot(image_region, context_labels)

        # Map the natural-language labels back to the short context keys used in the prompt library.
        context_mapping = {
            'bag panel with pattern': 'bag_panel',
            'luggage surface with branding': 'luggage_surface',
            'luxury trunk with monogram pattern': 'trunk_body',
            'vintage travel trunk with hardware': 'trunk_body',
            'shoe side view': 'shoe_side',
            'device back cover': 'device_back',
            'apparel chest area': 'apparel_chest',
            'belt buckle': 'belt_buckle',
            'storefront sign': 'storefront',
            'product tag or label': 'product_tag',
            'wallet surface': 'wallet',
            'perfume bottle': 'perfume_bottle',
            'watch dial or face': 'watch_dial',
            'car front grille': 'car_front',
            'laptop lid': 'laptop_lid'
        }

        top_context = max(scores.items(), key=lambda x: x[1])[0]
        return context_mapping.get(top_context, 'unknown')

    def _match_region_to_brand_context(self, region_context: str,
                                       brand_contexts: List[str]) -> Optional[str]:
        """Match the detected region context to one of the brand's available contexts."""
        if region_context in brand_contexts:
            return region_context

        # Fall back to a prefix match, e.g. 'bag_panel' can match any 'bag_*' context.
        for brand_context in brand_contexts:
            if region_context.split('_')[0] in brand_context:
                return brand_context
        return None

    def _fuzzy_text_matching(self, ocr_results: List[Dict]) -> Dict[str, Tuple[float, float]]:
        """Fuzzy text matching against brand aliases (optimized for logo text)."""
        matches = {}

        for ocr_item in ocr_results:
            text = ocr_item['text']
            conf = ocr_item['confidence']

            for brand_name, brand_info in self.flat_brands.items():
                all_names = [brand_name] + brand_info.get('aliases', [])

                for alias in all_names:
                    # Compare case-insensitively: OCR often returns all-caps logo text.
                    ratio = fuzz.ratio(text.lower(), alias.lower()) / 100.0
                    if ratio > 0.70:
                        # Keep the best ratio per brand across all OCR items.
                        if brand_name not in matches or ratio > matches[brand_name][0]:
                            matches[brand_name] = (ratio, conf)

        return matches
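    # Rough feel for the matching above (hedged; rapidfuzz's fuzz.ratio is a
    # 0-100 normalized indel similarity): a one-character OCR slip such as
    # fuzz.ratio("gucc1", "gucci") should score about 80, so 0.80 > 0.70 still
    # matches, while unrelated words fall well below the threshold.
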
    def _scale_visual(self, score: float) -> float:
        """Scale a visual score using a sigmoid (slope 10, centred at 0.5)."""
        return 1 / (1 + math.exp(-10 * (score - 0.5)))
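    # Reference points for the sigmoid above: raw 0.3 -> ~0.12, 0.5 -> 0.50,
    # 0.7 -> ~0.88. Scores near the centre are stretched apart while the
    # extremes saturate, sharpening the separation between weak and strong
    # visual matches before fusion.
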
    def _calculate_adaptive_weights(self, brand_name: str, visual_score: float,
                                    text_score: float, ocr_conf: float) -> Tuple[float, float, float]:
        """Calculate adaptive weights from brand characteristics and signal strengths.

        Args:
            brand_name: Name of the brand
            visual_score: Visual similarity score
            text_score: Text matching score
            ocr_conf: OCR confidence

        Returns:
            Tuple of (visual_weight, text_weight, ocr_weight), normalized to sum to 1
        """
        brand_info = self.prompt_library.get_brand_prompts(brand_name)

        if not brand_info:
            # Unknown brand: fall back to balanced default weights.
            return 0.50, 0.30, 0.20

        # Base profile depends on how the brand is usually recognized.
        if brand_info.get('visual_distinctive', False):
            # e.g. monogram patterns: trust the visual channel most.
            visual_weight = 0.65
            text_weight = 0.20
            ocr_weight = 0.15
        elif brand_info.get('text_prominent', False):
            # e.g. wordmark logos: trust OCR and text matching more.
            visual_weight = 0.30
            text_weight = 0.30
            ocr_weight = 0.40
        else:
            visual_weight = 0.50
            text_weight = 0.30
            ocr_weight = 0.20

        # Boost whichever channel is currently delivering a strong signal.
        if visual_score > 0.7:
            boost = 0.10
            visual_weight += boost
            text_weight -= boost * 0.5
            ocr_weight -= boost * 0.5

        if ocr_conf > 0.85:
            boost = 0.10
            ocr_weight += boost
            visual_weight -= boost * 0.6
            text_weight -= boost * 0.4

        if text_score > 0.80:
            boost = 0.08
            text_weight += boost
            visual_weight -= boost * 0.5
            ocr_weight -= boost * 0.5

        # Renormalize so the weights always sum to 1.
        total = visual_weight + text_weight + ocr_weight
        return visual_weight / total, text_weight / total, ocr_weight / total
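    # Worked example of the weighting above (numbers follow directly from the
    # rules): a visually distinctive brand with visual_score=0.75 starts at
    # (0.65, 0.20, 0.15); the visual boost of 0.10 shifts this to
    # (0.75, 0.15, 0.10), which already sums to 1.0, so normalization is a no-op.
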
    def _multi_scale_visual_matching(self, image_region: Image.Image,
                                     initial_scores: Dict[str, float],
                                     region_context: str = 'bag_panel') -> Dict[str, float]:
        """Apply multi-scale matching to improve robustness.

        Args:
            image_region: Image region to analyze
            initial_scores: Initial brand scores from single-scale matching
            region_context: Detected region context (defaults to 'bag_panel')

        Returns:
            Updated brand scores with multi-scale matching
        """
        scales = [0.8, 1.0, 1.2]
        multi_scale_scores = {brand: [] for brand in initial_scores}

        for scale in scales:
            new_width = int(image_region.width * scale)
            new_height = int(image_region.height * scale)

            # Skip scales that shrink the region below a usable size.
            if new_width < 50 or new_height < 50:
                continue

            try:
                scaled_img = image_region.resize((new_width, new_height), Image.Resampling.LANCZOS)

                for brand_name, brand_info in self.flat_brands.items():
                    best_context = self._match_region_to_brand_context(
                        region_context,
                        brand_info.get('region_contexts', [])
                    )

                    if best_context and best_context in brand_info.get('openclip_prompts', {}):
                        prompts = brand_info['openclip_prompts'][best_context]
                    else:
                        prompts = brand_info.get('strong_cues', [])[:3]

                    visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts)
                    avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
                    multi_scale_scores[brand_name].append(avg_score)

            except Exception:
                # A failure at one scale should not abort the whole pass.
                continue

        # Take the best score across scales; fall back to the single-scale score.
        final_scores = {}
        for brand_name, scores in multi_scale_scores.items():
            if scores:
                final_scores[brand_name] = max(scores)
            else:
                final_scores[brand_name] = initial_scores.get(brand_name, 0.0)

        return final_scores
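    # Design note on the max above: taking the best (rather than the mean)
    # across scales lets a brand keep its score if its logo matches well at
    # any one scale, which is the robustness the multi-scale pass is after.
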
    def scan_full_image_for_brands(self, full_image: Image.Image,
                                   exclude_bboxes: Optional[List[List[int]]] = None,
                                   saliency_regions: Optional[List[Dict]] = None) -> List[Tuple[str, float, List[int]]]:
        """Smart full-image brand scan (performance-optimized version).

        Uses prescreening and smart region selection to sharply reduce detection time.

        Args:
            full_image: PIL Image (full image)
            exclude_bboxes: List of bboxes to exclude (already detected)
            saliency_regions: Saliency detection results for smart region selection

        Returns:
            List of (brand_name, confidence, bbox) tuples
        """
        if exclude_bboxes is None:
            exclude_bboxes = []

        detected_brands = {}

        # Prescreen the whole image so only plausible brands are scanned per region.
        likely_brands = self.optimizer.quick_brand_prescreening(full_image)
        print(f"   Quick prescreening found {len(likely_brands)} potential brands")

        regions_to_scan = self.optimizer.smart_region_selection(full_image, saliency_regions or [])
        print(f"   Scanning {len(regions_to_scan)} intelligent regions")

        # Run OCR once on the full image; the result is identical for every region and brand.
        ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)

        for region_bbox in regions_to_scan:
            # Skip regions that overlap already-detected bboxes.
            if self._bbox_overlap(list(region_bbox), exclude_bboxes):
                continue

            region = full_image.crop(region_bbox)

            for brand_name in likely_brands:
                brand_info = self.flat_brands.get(brand_name)
                if not brand_info:
                    continue

                strong_cues = brand_info.get('strong_cues', [])[:5]
                if not strong_cues:
                    continue

                visual_scores = self.clip_manager.classify_zero_shot(region, strong_cues)
                avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0

                boosted_score = self.optimizer.compute_brand_confidence_boost(
                    brand_name, ocr_results, avg_score
                )

                if boosted_score > 0.08:
                    # Keep the best-scoring region per brand.
                    if brand_name not in detected_brands or boosted_score > detected_brands[brand_name][0]:
                        detected_brands[brand_name] = (boosted_score, list(region_bbox))

        final_brands = [
            (brand_name, confidence, bbox)
            for brand_name, (confidence, bbox) in detected_brands.items()
        ]

        final_brands.sort(key=lambda x: x[1], reverse=True)

        return final_brands[:5]

    def _bbox_overlap(self, bbox1: List[int], bbox_list: List[List[int]]) -> bool:
        """Check whether bbox1 overlaps significantly (>30%) with any bbox in bbox_list."""
        if not bbox_list:
            return False

        x1_1, y1_1, x2_1, y2_1 = bbox1
        bbox1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        if bbox1_area <= 0:
            # Degenerate box: nothing to overlap, and it would divide by zero below.
            return False

        for bbox2 in bbox_list:
            if bbox2 is None:
                continue

            x1_2, y1_2, x2_2, y2_2 = bbox2

            # Intersection rectangle.
            x_left = max(x1_1, x1_2)
            y_top = max(y1_1, y1_2)
            x_right = min(x2_1, x2_2)
            y_bottom = min(y2_1, y2_2)

            if x_right < x_left or y_bottom < y_top:
                continue

            intersection_area = (x_right - x_left) * (y_bottom - y_top)

            # Significant overlap: more than 30% of bbox1 is covered.
            if intersection_area / bbox1_area > 0.3:
                return True

        return False
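    # Design note: the criterion above is intersection-over-bbox1 (how much of
    # the candidate region is covered), not symmetric IoU, so a small candidate
    # lying inside a large already-detected box is correctly treated as covered.
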

print("✓ BrandRecognitionManager (with full-image scan for commercial use) defined")
|