|
|
import torch |
|
|
import json |
|
|
import re |
|
|
from PIL import Image |
|
|
from typing import List, Dict, Tuple |
|
|
from datetime import datetime |
|
|
from caption_generation_manager import CaptionGenerationManager |
|
|
|
|
|
class BrandVerificationManager:
    """VLM-based brand verification and three-way voting system.

    Combines three brand signals -- OpenCLIP detections, OCR text matches and
    a vision-language model (VLM) verdict -- into one ranked brand list.
    The VLM is reached through the ``processor`` and ``model`` attributes of
    the supplied caption generator.

    Annotations that name project or third-party types are quoted so they are
    not evaluated when the class body is executed.
    """

    # Score used when the VLM gives no usable confidence value.
    _DEFAULT_VLM_CONFIDENCE = 0.7

    # Weight of each source in the final weighted average.
    _SOURCE_WEIGHTS = {'vlm': 1.0, 'openclip': 0.6}
    # Weight for any source not listed above (in practice: OCR).
    _DEFAULT_SOURCE_WEIGHT = 0.4

    # Keys every verification result is expected to expose as lists.
    _RESULT_LIST_KEYS = ('verified_brands', 'false_positives', 'additional_brands')

    # Substrings that mark a response line as an affirmative statement.
    _AFFIRMATIVE_KEYWORDS = ('correct', 'yes', 'visible', 'see', 'identified')

    # One or more consecutive capitalised words, e.g. "Louis Vuitton".
    _CAPITALIZED_PHRASE = re.compile(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b')

    def __init__(self, caption_generator: "CaptionGenerationManager" = None):
        """
        Args:
            caption_generator: CaptionGenerationManager instance for VLM access.
                A new one is constructed when omitted.
        """
        if caption_generator is None:
            caption_generator = CaptionGenerationManager()

        self.caption_generator = caption_generator

        # Maps the VLM's textual confidence levels to numeric scores.
        self.confidence_map = {
            'high': 0.9,
            'medium': 0.7,
            'low': 0.5,
            'very high': 0.95,
            'very low': 0.3
        }

        print("✓ Brand Verification Manager initialized with VLM")

    def verify_brands(self, image: "Image.Image",
                      detected_brands: List[Tuple[str, float, list]]) -> Dict:
        """
        Use VLM to verify detected brands

        Args:
            image: PIL Image
            detected_brands: List of (brand_name, confidence, bbox) tuples

        Returns:
            Dictionary with ``verified_brands``, ``false_positives`` and
            ``additional_brands``. With no detections the lists are empty and
            a ``confidence`` of 0.0 is included. If the VLM call or parsing
            raises, every detected brand is reported as verified with
            'medium' confidence so the pipeline can continue.
        """
        if not detected_brands:
            return {
                'verified_brands': [],
                'false_positives': [],
                'additional_brands': [],
                'confidence': 0.0
            }

        # Only the first three detections are put in front of the VLM.
        brand_list = ', '.join([f"{brand[0]} (confidence: {brand[1]:.2f})"
                                for brand in detected_brands[:3]])

        verification_prompt = f"""Analyze this image carefully. Our computer vision system detected the following brands: {brand_list}.

Please verify each brand identification:

1. Are these brand identifications correct based on visible logos, patterns, text, or distinctive features?
2. If incorrect, what brands do you actually see (if any)?
3. Describe the visual evidence (logo shape, text, pattern, color scheme, hardware) that supports your conclusion.

Respond in JSON format:
{{
  "verified_brands": [
    {{"name": "Brand Name", "confidence": "high/medium/low", "evidence": "description of visual evidence"}}
  ],
  "false_positives": ["brand names that were incorrectly detected"],
  "additional_brands": ["brands we missed but you can see"]
}}

IMPORTANT: Only include brands you can clearly identify with visual evidence. If unsure, use "low" confidence."""

        try:
            response = self._generate_vlm_response(image, verification_prompt)
            return self._parse_verification_response(response)
        except Exception as e:
            # Deliberate best-effort: a VLM failure must not abort detection.
            print(f"VLM verification error: {e}")

            return {
                'verified_brands': [
                    {'name': brand[0], 'confidence': 'medium', 'evidence': 'VLM verification failed'}
                    for brand in detected_brands
                ],
                'false_positives': [],
                'additional_brands': []
            }

    def three_way_voting(self, openclip_brands: List[Tuple], ocr_brands: Dict,
                         vlm_result: Dict) -> List[Tuple[str, float, list]]:
        """
        Three-way voting: OpenCLIP vs OCR vs VLM

        OpenCLIP and OCR each contribute one vote per brand and the VLM
        contributes two. A brand listed in the VLM's ``false_positives``
        loses two votes (never dropping below zero). Brands left with no
        votes are discarded. The final confidence is a weighted average of
        the per-source scores (VLM 1.0, OpenCLIP 0.6, OCR 0.4), multiplied by
        1.15 when the brand holds at least two votes, capped at 0.95, and
        kept only when above 0.30.

        Args:
            openclip_brands: List of (brand_name, confidence, bbox) from OpenCLIP
            ocr_brands: Dict of {brand_name: (text_score, ocr_conf)} from OCR
            vlm_result: Verification result from VLM

        ``None`` is accepted for any argument and treated as empty.
        Malformed VLM entries (missing name, non-string name) are skipped,
        and a bare string entry is treated as a brand name with the default
        confidence.

        Returns:
            List of (brand_name, final_confidence, bbox) tuples, sorted by
            confidence in descending order. ``bbox`` is None for brands that
            OpenCLIP did not report.
        """
        votes = {}
        confidence_scores = {}
        vlm_result = vlm_result or {}

        # OpenCLIP: one vote, confidence damped to 80%.
        for brand_name, confidence, bbox in openclip_brands or []:
            self._record_vote(votes, confidence_scores, brand_name,
                              'openclip', 1, confidence * 0.8, bbox)

        # OCR: one vote, mean of text-match and OCR confidence damped to 70%.
        for brand_name, (text_score, ocr_conf) in (ocr_brands or {}).items():
            combined_ocr_score = (text_score + ocr_conf) / 2
            self._record_vote(votes, confidence_scores, brand_name,
                              'ocr', 1, combined_ocr_score * 0.7)

        # VLM: two votes per verified brand.
        for brand_info in vlm_result.get('verified_brands') or []:
            if isinstance(brand_info, str):
                brand_info = {'name': brand_info}
            if not isinstance(brand_info, dict):
                continue

            brand_name = brand_info.get('name')
            if not isinstance(brand_name, str) or not brand_name:
                continue

            vlm_confidence = self._confidence_to_score(
                brand_info.get('confidence', 'medium'))
            self._record_vote(votes, confidence_scores, brand_name,
                              'vlm', 2, vlm_confidence)

        # VLM-flagged false positives lose two votes.
        for false_positive in vlm_result.get('false_positives') or []:
            if isinstance(false_positive, str) and false_positive in votes:
                votes[false_positive]['votes'] = max(
                    0, votes[false_positive]['votes'] - 2)

        final_brands = []
        for brand_name, vote_info in votes.items():
            if vote_info['votes'] <= 0:
                continue

            scores = confidence_scores.get(brand_name, [])
            if not scores:
                continue

            weighted_sum = 0.0
            weight_total = 0.0
            for source, score in scores:
                weight = self._SOURCE_WEIGHTS.get(source, self._DEFAULT_SOURCE_WEIGHT)
                weighted_sum += score * weight
                weight_total += weight

            avg_confidence = weighted_sum / weight_total if weight_total > 0 else 0.0

            # Reward brands that hold two or more votes.
            if vote_info['votes'] >= 2:
                avg_confidence *= 1.15

            avg_confidence = min(avg_confidence, 0.95)

            if avg_confidence > 0.30:
                final_brands.append((brand_name, avg_confidence, vote_info['bbox']))

        final_brands.sort(key=lambda x: x[1], reverse=True)

        return final_brands

    @staticmethod
    def _record_vote(votes: Dict, confidence_scores: Dict, brand_name: str,
                     source: str, vote_count: int, score: float, bbox=None) -> None:
        """Add ``vote_count`` votes and one (source, score) pair for a brand.

        ``bbox`` is stored only when the brand is seen for the first time.
        """
        if brand_name not in votes:
            votes[brand_name] = {'votes': 0, 'sources': [], 'bbox': bbox}
            confidence_scores[brand_name] = []

        votes[brand_name]['votes'] += vote_count
        votes[brand_name]['sources'].append(source)
        confidence_scores[brand_name].append((source, score))

    def _confidence_to_score(self, level) -> float:
        """Convert a VLM confidence value to a score in [0, 1].

        Textual levels are looked up in ``self.confidence_map``
        (case-insensitive). Numbers and numeric strings are clamped to
        [0, 1]. Anything else yields the default of 0.7.
        """
        if isinstance(level, str):
            key = level.strip().lower()
            if key in self.confidence_map:
                return self.confidence_map[key]
        elif isinstance(level, bool) or level is None:
            return self._DEFAULT_VLM_CONFIDENCE

        try:
            value = float(level)
        except (TypeError, ValueError):
            return self._DEFAULT_VLM_CONFIDENCE

        if value != value:  # NaN never compares equal to itself
            return self._DEFAULT_VLM_CONFIDENCE
        return min(max(value, 0.0), 1.0)

    def extract_visual_evidence(self, image: "Image.Image", brand_name: str) -> Dict:
        """
        Extract detailed visual evidence for identified brand

        Args:
            image: PIL Image
            brand_name: Identified brand name

        Returns:
            Dictionary with ``brand``, ``evidence_description`` and an ISO
            ``timestamp``. If the VLM call raises, the description carries
            the error message instead.
        """
        evidence_prompt = f"""You identified {brand_name} in this image. Please describe the specific visual evidence:

1. Logo appearance: Describe the logo's shape, style, color, and exact location in the image
2. Text elements: What text did you see? (exact wording, font style, placement)
3. Distinctive patterns: Any signature patterns, textures, or design elements
4. Color scheme: Brand-specific colors used
5. Product features: Distinctive product design characteristics

Be specific and detailed. Focus on objective visual features."""

        try:
            evidence_description = self._generate_vlm_response(image, evidence_prompt)

            return {
                'brand': brand_name,
                'evidence_description': evidence_description,
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            return {
                'brand': brand_name,
                'evidence_description': f"Evidence extraction failed: {str(e)}",
                'timestamp': datetime.now().isoformat()
            }

    def _generate_vlm_response(self, image: "Image.Image", prompt: str) -> str:
        """
        Generate VLM response for given image and prompt

        Args:
            image: PIL Image
            prompt: Text prompt

        Returns:
            VLM response string (newly generated tokens only)
        """
        # Imported lazily so the module loads without qwen_vl_utils installed.
        from qwen_vl_utils import process_vision_info

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]

        text = self.caption_generator.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.caption_generator.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.caption_generator.model.device)

        # NOTE(review): temperature/top_p presumably only take effect when
        # sampling is enabled in the model's generation config -- confirm.
        generation_config = {
            'temperature': 0.3,
            'top_p': 0.9,
            'max_new_tokens': 300,
            'repetition_penalty': 1.1
        }

        generated_ids = self.caption_generator.model.generate(
            **inputs,
            **generation_config
        )

        # Drop the prompt tokens so only the generated continuation is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = self.caption_generator.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return output_text

    def _parse_verification_response(self, response: str) -> Dict:
        """
        Parse VLM verification response

        The span from the first ``{`` to the last ``}`` is decoded as JSON
        and normalised. If there is no such span, or it is not valid JSON,
        the rule-based parser is used instead.

        Args:
            response: VLM response string

        Returns:
            Parsed dictionary whose ``verified_brands``, ``false_positives``
            and ``additional_brands`` values are always lists
        """
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            try:
                result = json.loads(json_match.group())
            except json.JSONDecodeError:
                result = None
            if isinstance(result, dict):
                return self._normalize_verification_result(result)

        return self._rule_based_parse(response)

    @classmethod
    def _normalize_verification_result(cls, result: Dict) -> Dict:
        """Return a copy of ``result`` with the shape the voting step expects.

        Missing or null list keys become empty lists and a lone value is
        wrapped in a list. ``verified_brands`` entries given as plain strings
        become dicts with 'medium' confidence, and entries without a name
        are dropped. Any other keys are preserved unchanged.
        """
        normalized = dict(result)
        for key in cls._RESULT_LIST_KEYS:
            value = normalized.get(key)
            if value is None:
                normalized[key] = []
            elif not isinstance(value, list):
                normalized[key] = [value]

        verified = []
        for entry in normalized['verified_brands']:
            if isinstance(entry, str) and entry.strip():
                entry = {'name': entry.strip(), 'confidence': 'medium', 'evidence': ''}
            if isinstance(entry, dict) and entry.get('name'):
                verified.append(entry)
        normalized['verified_brands'] = verified

        return normalized

    def _rule_based_parse(self, response: str) -> Dict:
        """
        Fallback rule-based parsing if JSON fails

        Each line containing an affirmative keyword ('correct', 'yes',
        'visible', 'see', 'identified') is scanned for capitalised phrases.
        Each distinct phrase longer than two characters is reported once as
        a verified brand with 'medium' confidence. This is a heuristic:
        ordinary capitalised words such as "Yes" are not filtered out.

        Args:
            response: VLM response string

        Returns:
            Parsed dictionary
        """
        result = {
            'verified_brands': [],
            'false_positives': [],
            'additional_brands': []
        }

        seen = set()
        for line in response.split('\n'):
            lowered = line.lower()
            if not any(word in lowered for word in self._AFFIRMATIVE_KEYWORDS):
                continue

            # Search the original-case line: the pattern relies on capitals.
            for phrase in self._CAPITALIZED_PHRASE.findall(line):
                if len(phrase) <= 2 or phrase in seen:
                    continue
                seen.add(phrase)
                result['verified_brands'].append({
                    'name': phrase,
                    'confidence': 'medium',
                    'evidence': 'Extracted from VLM response'
                })

        return result
|
|
|
|
|
# Prints a confirmation line once the class definition above has been
# executed. NOTE(review): presumably a module-level, notebook-style marker;
# the original indentation is not recoverable from this view -- confirm.
print("✓ BrandVerificationManager (VLM verification and voting) defined")
|
|
|