# Pixcribe / brand_verification_manager.py
# Committed by DawnC in 6a3bd1f ("Upload 22 files", verified).
import torch
import json
import re
from PIL import Image
from typing import List, Dict, Tuple
from datetime import datetime
from caption_generation_manager import CaptionGenerationManager
class BrandVerificationManager:
    """VLM-based brand verification and three-way voting system.

    Brand candidates come from three sources: OpenCLIP detections, OCR text
    matches, and a vision-language model (VLM) reached through a
    ``CaptionGenerationManager``. They are merged into one ranked list of
    ``(brand_name, confidence, bbox)`` tuples.
    """

    # Keys the VLM is asked to return in its JSON verification answer.
    _RESULT_KEYS = ('verified_brands', 'false_positives', 'additional_brands')

    # Score used when a VLM confidence label is missing or unrecognised.
    _DEFAULT_VLM_CONFIDENCE = 0.7

    # Per-source weights for the weighted-average confidence in voting:
    # VLM highest, OpenCLIP medium, OCR lowest.
    _SOURCE_WEIGHTS = {'vlm': 1.0, 'openclip': 0.6, 'ocr': 0.4}

    # Rule-based fallback parsing. Whole-word cues are used so that
    # e.g. "incorrect" is not mistaken for the positive cue "correct".
    _SEGMENT_SPLIT = re.compile(r"[\n.!?;]+")
    _POSITIVE_CUES = re.compile(
        r"\b(?:correct|correctly|yes|visible|see|seen|identified)\b"
    )
    _NEGATIVE_CUES = re.compile(
        r"\b(?:not|no|incorrect|incorrectly|false|cannot|can't|don't|"
        r"doesn't|isn't|aren't)\b"
    )
    _CAPITALIZED_NAME = re.compile(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b")

    # Capitalised words that commonly open or close phrases in VLM answers
    # and are never brand names by themselves (compared case-folded).
    # Heuristic list, not exhaustive.
    _RULE_STOPWORDS = frozenset({
        'yes', 'the', 'this', 'that', 'these', 'those', 'there', 'here',
        'it', 'its', 'we', 'our', 'brand', 'brands', 'logo', 'logos',
        'image', 'picture', 'photo', 'visible', 'correct', 'correctly',
        'identified', 'confidence', 'high', 'medium', 'low', 'evidence',
        'see', 'seen', 'also', 'based', 'and', 'are', 'is', 'which', 'what',
        'verified', 'detected', 'clearly',
    })

    # Placeholder values a VLM may emit instead of an empty list.
    _NULL_NAMES = frozenset({'', 'none', 'n/a', 'na', 'null', 'unknown'})

    def __init__(self, caption_generator: "CaptionGenerationManager" = None):
        """
        Args:
            caption_generator: CaptionGenerationManager instance for VLM access.
                A new CaptionGenerationManager is created when omitted.
        """
        if caption_generator is None:
            caption_generator = CaptionGenerationManager()
        self.caption_generator = caption_generator
        # Confidence mapping for VLM responses (label -> score)
        self.confidence_map = {
            'high': 0.9,
            'medium': 0.7,
            'low': 0.5,
            'very high': 0.95,
            'very low': 0.3
        }
        print("✓ Brand Verification Manager initialized with VLM")

    def verify_brands(self, image: "Image.Image",
                      detected_brands: List[Tuple[str, float, list]]) -> Dict:
        """
        Use the VLM to verify detected brands.

        Args:
            image: PIL Image
            detected_brands: List of (brand_name, confidence, bbox) tuples

        Returns:
            Dictionary with 'verified_brands' (list of dicts with 'name',
            'confidence', 'evidence'), 'false_positives' and
            'additional_brands' (lists of names). With no detections, a
            'confidence': 0.0 entry is also included. If the VLM call fails,
            every detection is returned as 'medium'-confidence verified.
        """
        if not detected_brands:
            return {
                'verified_brands': [],
                'false_positives': [],
                'additional_brands': [],
                'confidence': 0.0
            }

        # Only the three most confident detections go into the prompt.
        # Sort explicitly rather than trusting the caller's order.
        top_brands = sorted(detected_brands, key=lambda b: b[1], reverse=True)[:3]
        brand_list = ', '.join(
            f"{brand[0]} (confidence: {brand[1]:.2f})" for brand in top_brands
        )

        verification_prompt = f"""Analyze this image carefully. Our computer vision system detected the following brands: {brand_list}.
Please verify each brand identification:
1. Are these brand identifications correct based on visible logos, patterns, text, or distinctive features?
2. If incorrect, what brands do you actually see (if any)?
3. Describe the visual evidence (logo shape, text, pattern, color scheme, hardware) that supports your conclusion.
Respond in JSON format:
{{
"verified_brands": [
{{"name": "Brand Name", "confidence": "high/medium/low", "evidence": "description of visual evidence"}}
],
"false_positives": ["brand names that were incorrectly detected"],
"additional_brands": ["brands we missed but you can see"]
}}
IMPORTANT: Only include brands you can clearly identify with visual evidence. If unsure, use "low" confidence."""

        # Generate VLM response. Any failure degrades to the raw detections
        # instead of aborting the caller's pipeline.
        try:
            response = self._generate_vlm_response(image, verification_prompt)
            return self._parse_verification_response(response)
        except Exception as e:
            print(f"VLM verification error: {e}")
            # Fallback to original detections
            return {
                'verified_brands': [
                    {'name': brand[0], 'confidence': 'medium', 'evidence': 'VLM verification failed'}
                    for brand in detected_brands
                ],
                'false_positives': [],
                'additional_brands': []
            }

    def three_way_voting(self, openclip_brands: List[Tuple], ocr_brands: Dict,
                         vlm_result: Dict) -> List[Tuple[str, float, list]]:
        """
        Three-way voting: OpenCLIP vs OCR vs VLM.

        Each source casts at most one vote per brand: OpenCLIP 1, OCR 1, and
        VLM 2. Repeated detections from the same source keep only that
        source's best score. A brand the VLM lists as a false positive loses
        2 votes. Brand names are matched case-insensitively across sources;
        the first spelling seen is reported.

        Args:
            openclip_brands: List of (brand_name, confidence, bbox) from OpenCLIP
            ocr_brands: Dict of {brand_name: (text_score, ocr_conf)} from OCR
            vlm_result: Verification result from the VLM. Entries of
                'verified_brands' may be dicts with 'name'/'confidence',
                or plain name strings.

        Returns:
            List of (brand_name, final_confidence, bbox) tuples, sorted by
            confidence descending. bbox is None for brands that OpenCLIP did
            not detect.
        """
        # case-folded name -> {'name', 'votes', 'scores': {source: score}, 'bbox'}
        tally: Dict[str, Dict] = {}

        def cast_vote(name, source, score, num_votes):
            # Returns the brand's tally entry, or None for unusable names.
            if not isinstance(name, str) or not name.strip():
                return None
            key = name.strip().casefold()
            entry = tally.setdefault(
                key, {'name': name, 'votes': 0, 'scores': {}, 'bbox': None}
            )
            previous = entry['scores'].get(source)
            if previous is None:
                entry['votes'] += num_votes
                entry['scores'][source] = score
            else:
                # Same source again: don't add a vote, keep the best score.
                entry['scores'][source] = max(previous, score)
            return entry

        # Vote 1: OpenCLIP (confidence scaled by 0.8)
        for brand_name, confidence, bbox in openclip_brands or []:
            entry = cast_vote(brand_name, 'openclip', confidence * 0.8, 1)
            if entry is not None and entry['bbox'] is None:
                entry['bbox'] = bbox

        # Vote 2: OCR (mean of text match and OCR confidence, scaled by 0.7)
        for brand_name, (text_score, ocr_conf) in (ocr_brands or {}).items():
            combined_ocr_score = (text_score + ocr_conf) / 2
            cast_vote(brand_name, 'ocr', combined_ocr_score * 0.7, 1)

        vlm_result = vlm_result or {}

        # Vote 3: VLM (double vote - treated as most reliable)
        for brand_info in vlm_result.get('verified_brands') or []:
            if isinstance(brand_info, dict):
                brand_name = brand_info.get('name')
                level = brand_info.get('confidence', 'medium')
            elif isinstance(brand_info, str):
                brand_name, level = brand_info, 'medium'
            else:
                continue
            cast_vote(brand_name, 'vlm', self._confidence_to_score(level), 2)

        # Penalise false positives flagged by the VLM (once per brand).
        flagged = {
            fp.strip().casefold()
            for fp in vlm_result.get('false_positives') or []
            if isinstance(fp, str)
        }
        for key in flagged:
            entry = tally.get(key)
            if entry is not None:
                entry['votes'] = max(0, entry['votes'] - 2)

        # Calculate final scores
        final_brands = []
        for entry in tally.values():
            if entry['votes'] <= 0 or not entry['scores']:
                continue  # vetoed, or no scores recorded

            weighted_sum = 0.0
            weight_total = 0.0
            for source, score in entry['scores'].items():
                weight = self._SOURCE_WEIGHTS[source]
                weighted_sum += score * weight
                weight_total += weight
            avg_confidence = weighted_sum / weight_total if weight_total > 0 else 0.0

            # Agreement boost. Because the VLM carries two votes, a VLM-only
            # confirmation also qualifies.
            if entry['votes'] >= 2:
                avg_confidence *= 1.15

            avg_confidence = min(avg_confidence, 0.95)

            # Only include if confidence is reasonable
            if avg_confidence > 0.30:
                final_brands.append((entry['name'], avg_confidence, entry['bbox']))

        final_brands.sort(key=lambda x: x[1], reverse=True)
        return final_brands

    def _confidence_to_score(self, level) -> float:
        """Map a VLM confidence (label or number) to a score in [0, 1].

        Labels are looked up in ``self.confidence_map``, ignoring case and
        surrounding whitespace. Numbers in (1, 100] are treated as
        percentages, and numbers are clamped to [0, 1]. Anything else yields
        the default of 0.7.
        """
        if isinstance(level, bool):
            return self._DEFAULT_VLM_CONFIDENCE
        if isinstance(level, (int, float)):
            value = float(level)
            if 1.0 < value <= 100.0:
                value /= 100.0
            return min(max(value, 0.0), 1.0)
        if isinstance(level, str):
            return self.confidence_map.get(level.strip().lower(), self._DEFAULT_VLM_CONFIDENCE)
        return self._DEFAULT_VLM_CONFIDENCE

    def extract_visual_evidence(self, image: "Image.Image", brand_name: str) -> Dict:
        """
        Ask the VLM for a detailed description of the visual evidence for a brand.

        Args:
            image: PIL Image
            brand_name: Identified brand name

        Returns:
            Dictionary with 'brand', 'evidence_description' and an ISO
            'timestamp'. On failure the description holds the error message.
        """
        evidence_prompt = f"""You identified {brand_name} in this image. Please describe the specific visual evidence:
1. Logo appearance: Describe the logo's shape, style, color, and exact location in the image
2. Text elements: What text did you see? (exact wording, font style, placement)
3. Distinctive patterns: Any signature patterns, textures, or design elements
4. Color scheme: Brand-specific colors used
5. Product features: Distinctive product design characteristics
Be specific and detailed. Focus on objective visual features."""
        try:
            evidence_description = self._generate_vlm_response(image, evidence_prompt)
            return {
                'brand': brand_name,
                'evidence_description': evidence_description,
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            return {
                'brand': brand_name,
                'evidence_description': f"Evidence extraction failed: {str(e)}",
                'timestamp': datetime.now().isoformat()
            }

    def _generate_vlm_response(self, image: "Image.Image", prompt: str) -> str:
        """
        Generate a VLM response for the given image and prompt.

        Args:
            image: PIL Image
            prompt: Text prompt

        Returns:
            Decoded text of the newly generated tokens only (prompt excluded).
        """
        from qwen_vl_utils import process_vision_info

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]
        text = self.caption_generator.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.caption_generator.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.caption_generator.model.device)

        # Low temperature for factual verification.
        # NOTE(review): temperature/top_p only take effect when sampling is
        # enabled (do_sample=True here or in the model's generation config)
        # -- confirm which decoding mode is intended.
        generation_config = {
            'temperature': 0.3,
            'top_p': 0.9,
            'max_new_tokens': 300,
            'repetition_penalty': 1.1
        }
        generated_ids = self.caption_generator.model.generate(
            **inputs,
            **generation_config
        )

        # Drop the prompt tokens so only the model's answer is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.caption_generator.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return output_text

    def _parse_verification_response(self, response: str) -> Dict:
        """
        Parse a VLM verification response.

        The first embedded JSON object carrying any of the expected result
        keys is used, even when surrounded by prose or code fences. Its
        lists are normalised so downstream voting can rely on their shape.
        Otherwise the rule-based parser is used.

        Args:
            response: VLM response string

        Returns:
            Dictionary with 'verified_brands', 'false_positives' and
            'additional_brands'.
        """
        text = response or ''
        raw = self._extract_json_object(text)
        if raw is not None:
            return self._normalize_verification(raw)
        # Fallback: rule-based parsing
        return self._rule_based_parse(text)

    def _extract_json_object(self, text: str):
        """Return the first JSON object in *text* that has a result key, else None."""
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                # raw_decode stops at the end of the object, ignoring any
                # trailing prose (unlike a greedy "{.*}" regex).
                candidate, _ = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and any(key in candidate for key in self._RESULT_KEYS):
                return candidate
            start = text.find('{', start + 1)
        return None

    @staticmethod
    def _as_list(value) -> list:
        """Coerce a JSON value to a list (a lone string/object becomes one item)."""
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, (str, dict)):
            return [value]
        return []

    def _brand_names(self, value) -> List[str]:
        """Extract stripped brand names from a list of strings or {'name': ...} dicts."""
        names = []
        for entry in self._as_list(value):
            name = entry.get('name') if isinstance(entry, dict) else entry
            if isinstance(name, str) and name.strip().casefold() not in self._NULL_NAMES:
                names.append(name.strip())
        return names

    def _normalize_verification(self, raw: Dict) -> Dict:
        """Coerce a decoded VLM JSON object into the documented result shape.

        Each 'verified_brands' item becomes a dict with at least 'name',
        'confidence' (default 'medium') and 'evidence' (default ''). The two
        other keys become lists of names. Unrecognised keys are kept.
        """
        result = dict(raw)
        verified = []
        for entry in self._as_list(raw.get('verified_brands')):
            if isinstance(entry, dict):
                name = entry.get('name')
                item = dict(entry)
            elif isinstance(entry, str):
                name = entry
                item = {}
            else:
                continue
            if not isinstance(name, str) or name.strip().casefold() in self._NULL_NAMES:
                continue
            item['name'] = name.strip()
            item.setdefault('confidence', 'medium')
            item.setdefault('evidence', '')
            verified.append(item)
        result['verified_brands'] = verified
        result['false_positives'] = self._brand_names(raw.get('false_positives'))
        result['additional_brands'] = self._brand_names(raw.get('additional_brands'))
        return result

    def _rule_based_parse(self, response: str) -> Dict:
        """
        Heuristic fallback parser used when no usable JSON is found.

        The response is split into lines/sentences. For each segment that
        contains a positive cue and no negation, capitalised phrases are
        extracted as brand names. Common filler words are trimmed from the
        ends of each phrase. Names are de-duplicated case-insensitively.

        Args:
            response: VLM response string

        Returns:
            Dictionary with 'verified_brands' filled (all 'medium' confidence);
            'false_positives' and 'additional_brands' are left empty.
        """
        result = {
            'verified_brands': [],
            'false_positives': [],
            'additional_brands': []
        }
        seen = set()
        for segment in self._SEGMENT_SPLIT.split(response or ''):
            lowered = segment.lower()
            if not self._POSITIVE_CUES.search(lowered):
                continue
            if self._NEGATIVE_CUES.search(lowered):
                continue  # e.g. "X is not visible" / "X is incorrect"
            # Extract from the original-case segment, not the whole response.
            for phrase in self._CAPITALIZED_NAME.findall(segment):
                words = phrase.split()
                while words and words[0].casefold() in self._RULE_STOPWORDS:
                    words.pop(0)
                while words and words[-1].casefold() in self._RULE_STOPWORDS:
                    words.pop()
                name = ' '.join(words)
                key = name.casefold()
                if len(name) <= 2 or key in seen:  # avoid short words / repeats
                    continue
                seen.add(key)
                result['verified_brands'].append({
                    'name': name,
                    'confidence': 'medium',
                    'evidence': 'Extracted from VLM response'
                })
        return result
# Import-time notice that this module has finished defining its class.
_MODULE_READY_MESSAGE = "✓ BrandVerificationManager (VLM verification and voting) defined"
print(_MODULE_READY_MESSAGE)