import torch
import open_clip
from PIL import Image
from typing import List, Dict
import numpy as np


class OpenCLIPSemanticManager:
    """Zero-shot classification and visual feature extraction with enhanced scene understanding"""

    def __init__(self):
        print("Loading OpenCLIP ViT-H/14 model...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            'ViT-H-14', pretrained='laion2b_s32b_b79k'
        )
        self.tokenizer = open_clip.get_tokenizer('ViT-H-14')

        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

        # Enhanced scene vocabularies
        self.scene_vocabularies = {
            'urban': [
                'city canyon with tall buildings', 'downtown street with skyscrapers',
                'urban corridor between buildings', 'busy city intersection',
                'metropolitan avenue'
            ],
            'lighting': [
                'overcast cloudy day', 'bright sunny day', 'golden hour warm glow',
                'blue hour twilight', 'harsh midday sun', 'soft diffused light',
                'dramatic evening light', 'moody overcast atmosphere'
            ],
            'mood': [
                'bustling and energetic', 'calm and contemplative',
                'dramatic and imposing', 'intimate and cozy', 'vibrant and lively'
            ]
        }

        # Hierarchical vocabularies
        self.coarse_labels = [
            'furniture', 'musical instrument', 'artwork', 'appliance', 'decoration',
            'tool', 'electronic device', 'clothing', 'accessory', 'food', 'plant'
        ]

        self.domain_vocabularies = {
            'musical instrument': [
                'acoustic guitar', 'electric guitar', 'bass guitar', 'classical guitar',
                'ukulele', 'violin', 'cello', 'piano', 'keyboard', 'drums',
                'saxophone', 'trumpet'
            ],
            'furniture': [
                'chair', 'sofa', 'table', 'desk', 'shelf', 'cabinet', 'bed',
                'stool', 'bench', 'wardrobe'
            ],
            'electronic device': [
                'smartphone', 'laptop', 'tablet', 'camera', 'headphones',
                'speaker', 'monitor', 'keyboard', 'mouse'
            ],
            'clothing': [
                'shirt', 'pants', 'dress', 'jacket', 'coat', 'sweater',
                'skirt', 'jeans', 'hoodie'
            ],
            'accessory': [
                'watch', 'sunglasses', 'hat', 'scarf', 'belt', 'bag',
                'wallet', 'jewelry', 'tie'
            ]
        }

        self.text_features_cache = {}
        self._cache_text_features()
        print("✓ OpenCLIP loaded with enhanced scene understanding")

    def _cache_text_features(self):
        """Pre-compute and cache text features"""
        with torch.no_grad():
            # Cache coarse labels
            prompts = [f"a photo of {label}" for label in self.coarse_labels]
            text = self.tokenizer(prompts)
            if torch.cuda.is_available():
                text = text.cuda()
            self.text_features_cache['coarse'] = self.model.encode_text(text)
            self.text_features_cache['coarse'] /= self.text_features_cache['coarse'].norm(dim=-1, keepdim=True)

            # Cache domain vocabularies
            for domain, labels in self.domain_vocabularies.items():
                prompts = [f"a photo of {label}" for label in labels]
                text = self.tokenizer(prompts)
                if torch.cuda.is_available():
                    text = text.cuda()
                features = self.model.encode_text(text)
                features /= features.norm(dim=-1, keepdim=True)
                self.text_features_cache[domain] = features

            # Cache scene vocabularies
            for scene_type, labels in self.scene_vocabularies.items():
                text = self.tokenizer(labels)
                if torch.cuda.is_available():
                    text = text.cuda()
                features = self.model.encode_text(text)
                features /= features.norm(dim=-1, keepdim=True)
                self.text_features_cache[f'scene_{scene_type}'] = features

    def analyze_scene(self, image: Image.Image) -> Dict:
        """Comprehensive scene analysis"""
        image_features = self.encode_image(image)
        scene_analysis = {}

        # Analyze each scene aspect
        for scene_type in ['urban', 'lighting', 'mood']:
            cache_key = f'scene_{scene_type}'
            similarity = (image_features @ self.text_features_cache[cache_key].T) / 0.01
            probs = similarity.softmax(dim=-1)

            results = {}
            for i, label in enumerate(self.scene_vocabularies[scene_type]):
                results[label] = float(probs[0, i].cpu())

            top_result = max(results.items(), key=lambda x: x[1])
            scene_analysis[scene_type] = {
                'top': top_result[0],
                'confidence': top_result[1],
                'all_scores': results
            }

        return scene_analysis

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image to feature vector"""
        with torch.no_grad():
            image_tensor = self.preprocess(image).unsqueeze(0)
            if torch.cuda.is_available():
                image_tensor = image_tensor.cuda()
            image_features = self.model.encode_image(image_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def encode_text(self, text_list: List[str]) -> torch.Tensor:
        """Encode text list to feature vectors"""
        with torch.no_grad():
            prompts = [f"a photo of {text}" for text in text_list]
            text = self.tokenizer(prompts)
            if torch.cuda.is_available():
                text = text.cuda()
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def classify_zero_shot(self, image: Image.Image, candidate_labels: List[str]) -> Dict[str, float]:
        """Zero-shot classification"""
        image_features = self.encode_image(image)
        text_features = self.encode_text(candidate_labels)

        similarity = (image_features @ text_features.T) / 0.01
        probs = similarity.softmax(dim=-1)

        results = {}
        for i, label in enumerate(candidate_labels):
            results[label] = float(probs[0, i].cpu())
        return results

    def classify_hierarchical(self, image: Image.Image) -> Dict:
        """Hierarchical classification"""
        image_features = self.encode_image(image)

        coarse_similarity = (image_features @ self.text_features_cache['coarse'].T) / 0.01
        coarse_probs = coarse_similarity.softmax(dim=-1)

        coarse_results = {}
        for i, label in enumerate(self.coarse_labels):
            coarse_results[label] = float(coarse_probs[0, i].cpu())

        top_category = max(coarse_results, key=coarse_results.get)

        if top_category in self.domain_vocabularies:
            fine_labels = self.domain_vocabularies[top_category]
            fine_similarity = (image_features @ self.text_features_cache[top_category].T) / 0.01
            fine_probs = fine_similarity.softmax(dim=-1)

            fine_results = {}
            for i, label in enumerate(fine_labels):
                fine_results[label] = float(fine_probs[0, i].cpu())

            top_prediction = max(fine_results, key=fine_results.get)
            return {
                'coarse': top_category,
                'fine': fine_results,
                'top_prediction': top_prediction,
                'confidence': fine_results[top_prediction]
            }

        return {
            'coarse': top_category,
            'top_prediction': top_category,
            'confidence': coarse_results[top_category]
        }


print("✓ OpenCLIPSemanticManager defined")
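
# --- Usage sketch (illustrative only; "street_scene.jpg" is a placeholder file name, not part of the pipeline above) ---
# manager = OpenCLIPSemanticManager()
# img = Image.open("street_scene.jpg").convert("RGB")
# scene = manager.analyze_scene(img)                       # top 'urban' / 'lighting' / 'mood' descriptions
# labels = manager.classify_zero_shot(img, ["acoustic guitar", "electric guitar", "violin"])
# hierarchy = manager.classify_hierarchical(img)           # coarse category, then fine-grained label
# print(scene["lighting"]["top"], hierarchy["top_prediction"])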