import math
from PIL import Image
from typing import Dict, List, Optional, Tuple
from rapidfuzz import fuzz
from prompt_library_manager import PromptLibraryManager
from brand_detection_optimizer import BrandDetectionOptimizer

class BrandRecognitionManager:
    """Multi-modal brand recognition with detailed prompts (Visual + Text)"""

    def __init__(self, clip_manager, ocr_manager, prompt_library: PromptLibraryManager):
        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        self.prompt_library = prompt_library
        self.flat_brands = prompt_library.get_all_brands()

        # Initialize optimizer for smart brand detection
        self.optimizer = BrandDetectionOptimizer(clip_manager, ocr_manager, prompt_library)

        print(f"✓ Brand Recognition Manager loaded with {len(self.flat_brands)} brands (with optimizer)")

    def recognize_brand(self, image_region: Image.Image, full_image: Image.Image,
                       region_bbox: Optional[List[int]] = None) -> List[Tuple[str, float, List[int]]]:
        """Recognize brands using detailed context-aware prompts

        Args:
            image_region: Cropped region containing potential brand
            full_image: Full image for OCR
            region_bbox: Bounding box [x1, y1, x2, y2] for visualization

        Returns:
            List of (brand_name, confidence, bbox) tuples
        """

        # Step 1: Classify region context
        region_context = self._classify_region_context(image_region)
        print(f"  [DEBUG] Region context classified as: {region_context}")

        # Step 2: Use context-specific OpenCLIP prompts
        brand_scores = {}

        for brand_name, brand_info in self.flat_brands.items():
            # Get best matching context for this brand
            best_context = self._match_region_to_brand_context(region_context, brand_info['region_contexts'])

            if best_context and best_context in brand_info['openclip_prompts']:
                # Use context-specific prompts
                prompts = brand_info['openclip_prompts'][best_context]
                visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts)

                # Average scores from all prompts
                avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
            else:
                # Fallback to strong cues
                prompts = brand_info['strong_cues'][:5]  # Top 5 strong cues
                visual_scores = self.clip_manager.classify_zero_shot(image_region, prompts)
                avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0

            brand_scores[brand_name] = avg_score

        # Step 2.5: Multi-scale visual matching for better robustness
        brand_scores = self._multi_scale_visual_matching(image_region, brand_scores)

        # Step 3: OCR text matching with brand-optimized preprocessing
        ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)
        text_matches = self._fuzzy_text_matching(ocr_results)

        print(f"  [DEBUG] OCR found {len(ocr_results)} text regions")
        if text_matches:
            print(f"  [DEBUG] OCR brand matches: {text_matches}")

        # Step 4: Adaptive weighted fusion (dynamic weights per brand)
        final_scores = {}
        for brand_name in self.flat_brands.keys():
            visual_score = brand_scores.get(brand_name, 0.0)
            text_score, ocr_conf = text_matches.get(brand_name, (0.0, 0.0))

            # Calculate adaptive weights based on brand characteristics
            visual_weight, text_weight, ocr_weight = self._calculate_adaptive_weights(
                brand_name, visual_score, text_score, ocr_conf
            )

            # Weighted fusion with adaptive weights
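            # Worked example (hypothetical values): with weights (0.5, 0.3, 0.2),
            # visual=0.6, text=0.9, ocr=0.8 the fused score is
            # 0.5 * sigmoid(0.6) + 0.3 * 0.9 + 0.2 * 0.8 ≈ 0.37 + 0.27 + 0.16 ≈ 0.80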
            final_score = (
                visual_weight * self._scale_visual(visual_score) +
                text_weight * text_score +
                ocr_weight * ocr_conf
            )
            final_scores[brand_name] = final_score

        sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:5]
        print(f"  [DEBUG] Top 5 brand scores:")
        for brand, score in sorted_scores:
            print(f"    {brand}: {score:.4f} (visual={brand_scores.get(brand, 0):.4f}, text={text_matches.get(brand, (0, 0))[0]:.4f})")

        # Return confident matches with bounding boxes
        confident_brands = []
        for brand_name, score in final_scores.items():
            if score > 0.10:
                confident_brands.append((brand_name, score, region_bbox))
                print(f"  [DEBUG] ✓ Brand detected: {brand_name} (confidence: {score:.4f})")

        confident_brands.sort(key=lambda x: x[1], reverse=True)

        if not confident_brands:
            print(f"  [DEBUG] ✗ No brands passed threshold 0.10")

        return confident_brands

    def _classify_region_context(self, image_region: Image.Image) -> str:
        """Classify what type of region this is (bag_panel, shoe_side, etc.)"""
        context_labels = [
            'bag panel with pattern',
            'luggage surface with branding',
            'luxury trunk with monogram pattern',
            'vintage travel trunk with hardware',
            'shoe side view',
            'device back cover',
            'apparel chest area',
            'belt buckle',
            'storefront sign',
            'product tag or label',
            'wallet surface',
            'perfume bottle',
            'watch dial or face',
            'car front grille',
            'laptop lid'
        ]

        scores = self.clip_manager.classify_zero_shot(image_region, context_labels)

        # Map to simplified contexts
        context_mapping = {
            'bag panel with pattern': 'bag_panel',
            'luggage surface with branding': 'luggage_surface',
            'luxury trunk with monogram pattern': 'trunk_body',
            'vintage travel trunk with hardware': 'trunk_body',
            'shoe side view': 'shoe_side',
            'device back cover': 'device_back',
            'apparel chest area': 'apparel_chest',
            'belt buckle': 'belt_buckle',
            'storefront sign': 'storefront',
            'product tag or label': 'product_tag',
            'wallet surface': 'wallet',
            'perfume bottle': 'perfume_bottle',
            'watch dial or face': 'watch_dial',
            'car front grille': 'car_front',
            'laptop lid': 'laptop_lid'
        }

        top_context = max(scores.items(), key=lambda x: x[1])[0]
        return context_mapping.get(top_context, 'unknown')

    def _match_region_to_brand_context(self, region_context: str, brand_contexts: List[str]) -> Optional[str]:
        """Match detected region context to brand's available contexts"""
        if region_context in brand_contexts:
            return region_context
        # Fuzzy matching
        for brand_context in brand_contexts:
            if region_context.split('_')[0] in brand_context:
                return brand_context
        return None

    def _fuzzy_text_matching(self, ocr_results: List[Dict]) -> Dict[str, Tuple[float, float]]:
        """Fuzzy text matching using brand aliases (optimized for logo text)"""
        matches = {}

        for ocr_item in ocr_results:
            text = ocr_item['text']
            conf = ocr_item['confidence']

            for brand_name, brand_info in self.flat_brands.items():
                # Check all aliases
                all_names = [brand_name] + brand_info.get('aliases', [])

                for alias in all_names:
                    # Compare case-insensitively: OCR of logo text is often all-caps
                    ratio = fuzz.ratio(text.lower(), alias.lower()) / 100.0
                    if ratio > 0.70:  # Lowered threshold for better recall
                        if brand_name not in matches or ratio > matches[brand_name][0]:
                            matches[brand_name] = (ratio, conf)

        return matches

    def _scale_visual(self, score: float) -> float:
        """Scale visual score using sigmoid"""
        return 1 / (1 + math.exp(-10 * (score - 0.5)))

    def _calculate_adaptive_weights(self, brand_name: str, visual_score: float,
                                     text_score: float, ocr_conf: float) -> Tuple[float, float, float]:
        """
        Calculate adaptive weights based on brand characteristics and signal strengths

        Args:
            brand_name: Name of the brand
            visual_score: Visual similarity score
            text_score: Text matching score
            ocr_conf: OCR confidence

        Returns:
            Tuple of (visual_weight, text_weight, ocr_weight)
        """
        brand_info = self.prompt_library.get_brand_prompts(brand_name)

        if not brand_info:
            # Default balanced weights
            return 0.50, 0.30, 0.20

        # Base weights based on brand characteristics
        if brand_info.get('visual_distinctive', False):
            # Visually distinctive brands (LV, Burberry)
            visual_weight = 0.65
            text_weight = 0.20
            ocr_weight = 0.15
        elif brand_info.get('text_prominent', False):
            # Text-prominent brands (Nike, Adidas)
            visual_weight = 0.30
            text_weight = 0.30
            ocr_weight = 0.40
        else:
            # Balanced for general brands
            visual_weight = 0.50
            text_weight = 0.30
            ocr_weight = 0.20

        # Dynamic adjustment based on signal strength
        # If visual signal is very strong, boost its weight
        if visual_score > 0.7:
            boost = 0.10
            visual_weight += boost
            text_weight -= boost * 0.5
            ocr_weight -= boost * 0.5

        # If OCR has very high confidence, boost its weight
        if ocr_conf > 0.85:
            boost = 0.10
            ocr_weight += boost
            visual_weight -= boost * 0.6
            text_weight -= boost * 0.4

        # If text match is very strong, boost its weight
        if text_score > 0.80:
            boost = 0.08
            text_weight += boost
            visual_weight -= boost * 0.5
            ocr_weight -= boost * 0.5

        # Normalize weights to sum to 1
        total = visual_weight + text_weight + ocr_weight
        return visual_weight / total, text_weight / total, ocr_weight / total

    def _multi_scale_visual_matching(self, image_region: Image.Image,
                                      initial_scores: Dict[str, float]) -> Dict[str, float]:
        """
        Apply multi-scale matching to improve robustness

        Args:
            image_region: Image region to analyze
            initial_scores: Initial brand scores from single-scale matching

        Returns:
            Updated brand scores with multi-scale matching
        """
        scales = [0.8, 1.0, 1.2]  # Three scales
        multi_scale_scores = {brand: [] for brand in initial_scores.keys()}

        for scale in scales:
            # Resize image
            new_width = int(image_region.width * scale)
            new_height = int(image_region.height * scale)

            # Ensure minimum size
            if new_width < 50 or new_height < 50:
                continue

            try:
                scaled_img = image_region.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Re-run classification on each brand's prompts
                for brand_name, brand_info in self.flat_brands.items():
                    # Get context-specific prompts
                    best_context = self._match_region_to_brand_context(
                        'bag_panel',  # Default context, ideally should be passed as parameter
                        brand_info.get('region_contexts', [])
                    )

                    if best_context and best_context in brand_info.get('openclip_prompts', {}):
                        prompts = brand_info['openclip_prompts'][best_context]
                        visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts)
                        avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0
                    else:
                        prompts = brand_info.get('strong_cues', [])[:3]
                        visual_scores = self.clip_manager.classify_zero_shot(scaled_img, prompts)
                        avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0

                    multi_scale_scores[brand_name].append(avg_score)

            except Exception:
                # Skip this scale if resizing or classification fails
                continue

        # Aggregate multi-scale scores (use max score across scales)
        final_scores = {}
        for brand_name, scores in multi_scale_scores.items():
            if scores:
                final_scores[brand_name] = max(scores)
            else:
                final_scores[brand_name] = initial_scores.get(brand_name, 0.0)

        return final_scores

    def scan_full_image_for_brands(self, full_image: Image.Image,
                                    exclude_bboxes: Optional[List[List[int]]] = None,
                                    saliency_regions: Optional[List[Dict]] = None) -> List[Tuple[str, float, List[int]]]:
        """
        Smart full-image brand scan (performance-optimized version).
        Uses prescreening and smart region selection to greatly reduce detection time.

        Args:
            full_image: PIL Image (full image)
            exclude_bboxes: List of bboxes to exclude (already detected)
            saliency_regions: Saliency detection results for smart region selection

        Returns:
            List of (brand_name, confidence, bbox) tuples
        """
        if exclude_bboxes is None:
            exclude_bboxes = []

        detected_brands = {}  # brand_name -> (confidence, bbox)
        img_width, img_height = full_image.size

        # OPTIMIZATION 1: quick brand prescreening
        likely_brands = self.optimizer.quick_brand_prescreening(full_image)
        print(f"      Quick prescreening found {len(likely_brands)} potential brands")

        # OPTIMIZATION 2: smart region selection (scan only meaningful regions)
        regions_to_scan = self.optimizer.smart_region_selection(full_image, saliency_regions or [])
        print(f"      Scanning {len(regions_to_scan)} intelligent regions")

        # Run OCR once on the full image; the results depend on neither the
        # region nor the brand, so hoisting this out of the loops below avoids
        # re-running OCR for every (region, brand) pair
        ocr_results = self.ocr_manager.extract_text(full_image, use_brand_preprocessing=True)

        # Scan the selected regions
        for region_bbox in regions_to_scan:
            # Skip regions that overlap already-detected bboxes
            if self._bbox_overlap(list(region_bbox), exclude_bboxes):
                continue

            # Crop the candidate region
            region = full_image.crop(region_bbox)

            # Only test the prescreened brands (rather than all 20+ brands)
            for brand_name in likely_brands:
                brand_info = self.flat_brands.get(brand_name)
                if not brand_info:
                    continue

                # Use only the strong cues for this fast scan
                strong_cues = brand_info.get('strong_cues', [])[:5]  # Top 5
                if not strong_cues:
                    continue

                visual_scores = self.clip_manager.classify_zero_shot(region, strong_cues)
                avg_score = sum(visual_scores.values()) / len(visual_scores) if visual_scores else 0.0

                # OCR-based confidence boost
                boosted_score = self.optimizer.compute_brand_confidence_boost(
                    brand_name, ocr_results, avg_score
                )

                # Extremely lenient threshold to maximize detection rate
                if boosted_score > 0.08:  # lowered to 0.08
                    # Keep the best score/bbox per brand
                    if brand_name not in detected_brands or boosted_score > detected_brands[brand_name][0]:
                        detected_brands[brand_name] = (boosted_score, list(region_bbox))

        # Convert to a list of (brand_name, confidence, bbox) tuples
        final_brands = [
            (brand_name, confidence, bbox)
            for brand_name, (confidence, bbox) in detected_brands.items()
        ]

        # Sort by confidence, highest first
        final_brands.sort(key=lambda x: x[1], reverse=True)

        return final_brands[:5]  # return the top 5

    def _bbox_overlap(self, bbox1: List[int], bbox_list: List[List[int]]) -> bool:
        """Check if bbox1 overlaps significantly with any bbox in bbox_list"""
        if not bbox_list:
            return False

        x1_1, y1_1, x2_1, y2_1 = bbox1

        for bbox2 in bbox_list:
            if bbox2 is None:
                continue

            x1_2, y1_2, x2_2, y2_2 = bbox2

            # Calculate intersection
            x_left = max(x1_1, x1_2)
            y_top = max(y1_1, y1_2)
            x_right = min(x2_1, x2_2)
            y_bottom = min(y2_1, y2_2)

            if x_right < x_left or y_bottom < y_top:
                continue

            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            bbox1_area = (x2_1 - x1_1) * (y2_1 - y1_1)

            # Treat as overlapping when the intersection covers more than 30%
            # of bbox1 (guard against degenerate zero-area boxes)
            if bbox1_area > 0 and intersection_area / bbox1_area > 0.3:
                return True

        return False

print("✓ BrandRecognitionManager (with full-image scan for commercial use) defined")