# Pixcribe / brand_recognition_manager.py
import math
from typing import Dict, List, Optional, Tuple

from PIL import Image
from rapidfuzz import fuzz

from prompt_library_manager import PromptLibraryManager
from brand_detection_optimizer import BrandDetectionOptimizer
class BrandRecognitionManager:
"""Multi-modal brand recognition with detailed prompts (Visual + Text)"""
    def __init__(self, clip_manager, ocr_manager, prompt_library: Optional[PromptLibraryManager] = None):
        if prompt_library is None:
            # A prompt library is required downstream; fail fast with a clear
            # error instead of crashing on the get_all_brands() call below
            raise ValueError("BrandRecognitionManager requires a PromptLibraryManager instance")
        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        self.prompt_library = prompt_library
        self.flat_brands = prompt_library.get_all_brands()
        # Initialize optimizer for smart brand detection
        self.optimizer = BrandDetectionOptimizer(clip_manager, ocr_manager, prompt_library)
        print(f"✓ Brand Recognition Manager loaded with {len(self.flat_brands)} brands (with optimizer)")
    def recognize_brand(self, image_region: Image.Image, full_image: Image.Image,
                        region_bbox: Optional[List[int]] = None) -> List[Tuple[str, float, List[int]]]:
"""Recognize brands using detailed context-aware prompts
Args:
image_region: Cropped region containing potential brand
full_image: Full image for OCR
region_bbox: Bounding box [x1, y1, x2, y2] for visualization
Returns:
List of (brand_name, confidence, bbox) tuples
"""
# Step 1: Classify region context
region_context = self._classify_region_context(image_region)
print(f" [DEBUG] Region context classified as: {region_context}")
# Step 2: Use context-specific OpenCLIP prompts
brand_scores = {}
for brand_name, brand_info in self.flat_brands.items():
# Get best matching context for this brand
            best_context = self._match_region_to_brand_context(region_context, brand_info.get('region_contexts', []))
            if best_context and best_context in brand_info.get('openclip_prompts', {}):
                # Use context-specific prompts
                prompts = brand_info['openclip_prompts'][best_context]
            else:
                # Fall back to the brand's strongest visual cues
                prompts = brand_info.get('strong_cues', [])[:5]  # top 5 strong cues
            visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts)
            # Average per-prompt scores into one visual score for this brand
            avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
            brand_scores[brand_name] = avg_score
# Step 2.5: Multi-scale visual matching for better robustness
brand_scores = self._multi_scale_visual_matching(image_region, brand_scores)
# Step 3: OCR text matching with brand-optimized preprocessing
ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)
text_matches = self._fuzzy_text_matching(ocr_results)
print(f" [DEBUG] OCR found {len(ocr_results)} text regions")
if text_matches:
print(f" [DEBUG] OCR brand matches: {text_matches}")
# Step 4: Adaptive weighted fusion (dynamic weights per brand)
final_scores = {}
for brand_name in self.flat_brands.keys():
visual_score = brand_scores.get(brand_name, 0.0)
text_score, ocr_conf = text_matches.get(brand_name, (0.0, 0.0))
# Calculate adaptive weights based on brand characteristics
visual_weight, text_weight, ocr_weight = self._calculate_adaptive_weights(
brand_name, visual_score, text_score, ocr_conf
)
# Weighted fusion with adaptive weights
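            # e.g. weights (0.5, 0.3, 0.2) with visual_score 0.6 (sigmoid-scaled
            # to ~0.73), text_score 0.8 and ocr_conf 0.9 give
            # 0.5*0.73 + 0.3*0.8 + 0.2*0.9 ≈ 0.79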
final_score = (
visual_weight * self._scale_visual(visual_score) +
text_weight * text_score +
ocr_weight * ocr_conf
)
final_scores[brand_name] = final_score
sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:5]
print(f" [DEBUG] Top 5 brand scores:")
for brand, score in sorted_scores:
print(f" {brand}: {score:.4f} (visual={brand_scores.get(brand, 0):.4f}, text={text_matches.get(brand, (0, 0))[0]:.4f})")
# Return confident matches with bounding boxes
confident_brands = []
for brand_name, score in final_scores.items():
if score > 0.10:
confident_brands.append((brand_name, score, region_bbox))
print(f" [DEBUG] ✓ Brand detected: {brand_name} (confidence: {score:.4f})")
confident_brands.sort(key=lambda x: x[1], reverse=True)
if not confident_brands:
print(f" [DEBUG] ✗ No brands passed threshold 0.10")
return confident_brands
def _classify_region_context(self, image_region: Image.Image) -> str:
"""Classify what type of region this is (bag_panel, shoe_side, etc.)"""
        # Single source of truth: descriptive zero-shot labels mapped to
        # simplified context keys
        context_mapping = {
            'bag panel with pattern': 'bag_panel',
            'luggage surface with branding': 'luggage_surface',
            'luxury trunk with monogram pattern': 'trunk_body',
            'vintage travel trunk with hardware': 'trunk_body',
            'shoe side view': 'shoe_side',
            'device back cover': 'device_back',
            'apparel chest area': 'apparel_chest',
            'belt buckle': 'belt_buckle',
            'storefront sign': 'storefront',
            'product tag or label': 'product_tag',
            'wallet surface': 'wallet',
            'perfume bottle': 'perfume_bottle',
            'watch dial or face': 'watch_dial',
            'car front grille': 'car_front',
            'laptop lid': 'laptop_lid'
        }
        context_labels = list(context_mapping.keys())
        scores = self.clip_manager.classify_zero_shot(image_region, context_labels)
top_context = max(scores.items(), key=lambda x: x[1])[0]
return context_mapping.get(top_context, 'unknown')
    def _match_region_to_brand_context(self, region_context: str, brand_contexts: List[str]) -> Optional[str]:
        """Match the detected region context to one of the brand's available contexts"""
        if region_context in brand_contexts:
            return region_context
        # Loose fallback: match on the region's prefix as a substring,
        # e.g. a 'shoe_*' region can map to any brand context containing 'shoe'
        for brand_context in brand_contexts:
            if region_context.split('_')[0] in brand_context:
                return brand_context
        return None
def _fuzzy_text_matching(self, ocr_results: List[Dict]) -> Dict[str, Tuple[float, float]]:
"""Fuzzy text matching using brand aliases (optimized for logo text)"""
matches = {}
for ocr_item in ocr_results:
text = ocr_item['text']
conf = ocr_item['confidence']
for brand_name, brand_info in self.flat_brands.items():
# Check all aliases
all_names = [brand_name] + brand_info.get('aliases', [])
for alias in all_names:
                    # Case-insensitive compare: OCR output and brand aliases often differ in casing
                    ratio = fuzz.ratio(text.lower(), alias.lower()) / 100.0
                    if ratio > 0.70:  # permissive threshold for better recall
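                        # e.g. fuzz.ratio("guci", "gucci") ≈ 88.9, so a single
                        # character dropped by OCR still clears the 0.70 bar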
if brand_name not in matches or ratio > matches[brand_name][0]:
matches[brand_name] = (ratio, conf)
return matches
def _scale_visual(self, score: float) -> float:
"""Scale visual score using sigmoid"""
return 1 / (1 + math.exp(-10 * (score - 0.5)))
    def _calculate_adaptive_weights(self, brand_name: str, visual_score: float,
                                    text_score: float, ocr_conf: float) -> Tuple[float, float, float]:
"""
Calculate adaptive weights based on brand characteristics and signal strengths
Args:
brand_name: Name of the brand
visual_score: Visual similarity score
text_score: Text matching score
ocr_conf: OCR confidence
Returns:
Tuple of (visual_weight, text_weight, ocr_weight)
"""
brand_info = self.prompt_library.get_brand_prompts(brand_name)
if not brand_info:
# Default balanced weights
return 0.50, 0.30, 0.20
# Base weights based on brand characteristics
if brand_info.get('visual_distinctive', False):
# Visually distinctive brands (LV, Burberry)
visual_weight = 0.65
text_weight = 0.20
ocr_weight = 0.15
elif brand_info.get('text_prominent', False):
# Text-prominent brands (Nike, Adidas)
visual_weight = 0.30
text_weight = 0.30
ocr_weight = 0.40
else:
# Balanced for general brands
visual_weight = 0.50
text_weight = 0.30
ocr_weight = 0.20
# Dynamic adjustment based on signal strength
# If visual signal is very strong, boost its weight
if visual_score > 0.7:
boost = 0.10
visual_weight += boost
text_weight -= boost * 0.5
ocr_weight -= boost * 0.5
# If OCR has very high confidence, boost its weight
if ocr_conf > 0.85:
boost = 0.10
ocr_weight += boost
visual_weight -= boost * 0.6
text_weight -= boost * 0.4
# If text match is very strong, boost its weight
if text_score > 0.80:
boost = 0.08
text_weight += boost
visual_weight -= boost * 0.5
ocr_weight -= boost * 0.5
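        # Worked example: a visually distinctive brand with visual_score 0.75,
        # text_score 0.50 and ocr_conf 0.90 starts at (0.65, 0.20, 0.15); the
        # visual boost gives (0.75, 0.15, 0.10) and the OCR boost (0.69, 0.11,
        # 0.20), which already sums to 1.0, so normalization changes nothing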
# Normalize weights to sum to 1
total = visual_weight + text_weight + ocr_weight
return visual_weight / total, text_weight / total, ocr_weight / total
def _multi_scale_visual_matching(self, image_region: Image.Image,
initial_scores: Dict[str, float]) -> Dict[str, float]:
"""
Apply multi-scale matching to improve robustness
Args:
image_region: Image region to analyze
initial_scores: Initial brand scores from single-scale matching
Returns:
Updated brand scores with multi-scale matching
"""
scales = [0.8, 1.0, 1.2] # Three scales
multi_scale_scores = {brand: [] for brand in initial_scores.keys()}
for scale in scales:
# Resize image
new_width = int(image_region.width * scale)
new_height = int(image_region.height * scale)
# Ensure minimum size
if new_width < 50 or new_height < 50:
continue
try:
scaled_img = image_region.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Re-run classification on each brand's prompts
for brand_name, brand_info in self.flat_brands.items():
# Get context-specific prompts
best_context = self._match_region_to_brand_context(
'bag_panel', # Default context, ideally should be passed as parameter
brand_info.get('region_contexts', [])
)
if best_context and best_context in brand_info.get('openclip_prompts', {}):
prompts = brand_info['openclip_prompts'][best_context]
visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts)
avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
else:
prompts = brand_info.get('strong_cues', [])[:3]
visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts)
avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
multi_scale_scores[brand_name].append(avg_score)
            except Exception:
                # Skip this scale if resizing or classification fails
                continue
# Aggregate multi-scale scores (use max score across scales)
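        # max() keeps the scale at which the logo is most legible; averaging
        # would dilute a strong hit at one scale with misses at the others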
final_scores = {}
for brand_name, scores in multi_scale_scores.items():
if scores:
final_scores[brand_name] = max(scores)
else:
final_scores[brand_name] = initial_scores.get(brand_name, 0.0)
return final_scores
    def scan_full_image_for_brands(self, full_image: Image.Image,
                                   exclude_bboxes: Optional[List[List[int]]] = None,
                                   saliency_regions: Optional[List[Dict]] = None) -> List[Tuple[str, float, List[int]]]:
        """
        Smart full-image brand scan (performance-optimized version).
        Uses prescreening and smart region selection to drastically cut detection time.
        Args:
            full_image: PIL Image (full image)
            exclude_bboxes: List of bboxes to exclude (already detected)
            saliency_regions: Saliency detection results for smart region selection
        Returns:
            List of (brand_name, confidence, bbox) tuples
        """
        if exclude_bboxes is None:
            exclude_bboxes = []
        detected_brands = {}  # brand_name -> (confidence, bbox)
        # OPTIMIZATION 1: quick brand prescreening
        likely_brands = self.optimizer.quick_brand_prescreening(full_image)
        print(f" Quick prescreening found {len(likely_brands)} potential brands")
        # OPTIMIZATION 2: smart region selection (scan only meaningful regions)
        regions_to_scan = self.optimizer.smart_region_selection(full_image, saliency_regions or [])
        print(f" Scanning {len(regions_to_scan)} intelligent regions")
        # Run OCR once up front: the result depends only on the full image, so
        # calling it per brand inside the loop below would repeat identical work
        ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)
        # Scan the selected regions
        for region_bbox in regions_to_scan:
            # Skip regions already covered by earlier detections
            if self._bbox_overlap(list(region_bbox), exclude_bboxes):
                continue
            # Extract the region
            region = full_image.crop(region_bbox)
            # Only test the prescreened brands (rather than all 20+ brands)
            for brand_name in likely_brands:
                brand_info = self.flat_brands.get(brand_name)
                if not brand_info:
                    continue
                # Use only strong cues to keep the scan fast
                strong_cues = brand_info.get('strong_cues', [])[:5]  # top 5
                if not strong_cues:
                    continue
                visual_scores = self.clip_manager.classify_zero_shot(region, strong_cues)
                avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
                # OCR-based confidence boost
                boosted_score = self.optimizer.compute_brand_confidence_boost(
                    brand_name, ocr_results, avg_score
                )
                # Very permissive threshold to maximize detection rate
                if boosted_score > 0.08:  # lowered to 0.08
                    # Keep the best result per brand
                    if brand_name not in detected_brands or boosted_score > detected_brands[brand_name][0]:
                        detected_brands[brand_name] = (boosted_score, list(region_bbox))
        # Convert to list format
        final_brands = [
            (brand_name, confidence, bbox)
            for brand_name, (confidence, bbox) in detected_brands.items()
        ]
        # Sort by confidence
        final_brands.sort(key=lambda x: x[1], reverse=True)
        return final_brands[:5]  # return the top 5
def _bbox_overlap(self, bbox1: List[int], bbox_list: List[List[int]]) -> bool:
"""Check if bbox1 overlaps significantly with any bbox in bbox_list"""
if not bbox_list:
return False
x1_1, y1_1, x2_1, y2_1 = bbox1
for bbox2 in bbox_list:
if bbox2 is None:
continue
x1_2, y1_2, x2_2, y2_2 = bbox2
# Calculate intersection
x_left = max(x1_1, x1_2)
y_top = max(y1_1, y1_2)
x_right = min(x2_1, x2_2)
y_bottom = min(y2_1, y2_2)
if x_right < x_left or y_bottom < y_top:
continue
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            bbox1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
            # Treat >30% coverage of bbox1's own area (not IoU) as overlap;
            # guard against degenerate zero-area boxes
            if bbox1_area > 0 and intersection_area / bbox1_area > 0.3:
                return True
return False
print("✓ BrandRecognitionManager (with full-image scan for commercial use) defined")