# Pixcribe: openclip_semantic_manager.py
import torch
import open_clip
from PIL import Image
from typing import List, Dict
import numpy as np

class OpenCLIPSemanticManager:
    """Zero-shot classification and visual feature extraction with enhanced scene understanding"""

    def __init__(self):
        print("Loading OpenCLIP ViT-H/14 model...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            'ViT-H-14',
            pretrained='laion2b_s32b_b79k'
        )
        self.tokenizer = open_clip.get_tokenizer('ViT-H-14')
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()
        # Enhanced scene vocabularies
        self.scene_vocabularies = {
            'urban': [
                'city canyon with tall buildings',
                'downtown street with skyscrapers',
                'urban corridor between buildings',
                'busy city intersection',
                'metropolitan avenue'
            ],
            'lighting': [
                'overcast cloudy day',
                'bright sunny day',
                'golden hour warm glow',
                'blue hour twilight',
                'harsh midday sun',
                'soft diffused light',
                'dramatic evening light',
                'moody overcast atmosphere'
            ],
            'mood': [
                'bustling and energetic',
                'calm and contemplative',
                'dramatic and imposing',
                'intimate and cozy',
                'vibrant and lively'
            ]
        }
        # Hierarchical vocabularies
        self.coarse_labels = [
            'furniture', 'musical instrument', 'artwork',
            'appliance', 'decoration', 'tool', 'electronic device',
            'clothing', 'accessory', 'food', 'plant'
        ]
        self.domain_vocabularies = {
            'musical instrument': [
                'acoustic guitar', 'electric guitar', 'bass guitar',
                'classical guitar', 'ukulele', 'violin', 'cello',
                'piano', 'keyboard', 'drums', 'saxophone', 'trumpet'
            ],
            'furniture': [
                'chair', 'sofa', 'table', 'desk', 'shelf',
                'cabinet', 'bed', 'stool', 'bench', 'wardrobe'
            ],
            'electronic device': [
                'smartphone', 'laptop', 'tablet', 'camera',
                'headphones', 'speaker', 'monitor', 'keyboard', 'mouse'
            ],
            'clothing': [
                'shirt', 'pants', 'dress', 'jacket', 'coat',
                'sweater', 'skirt', 'jeans', 'hoodie'
            ],
            'accessory': [
                'watch', 'sunglasses', 'hat', 'scarf', 'belt',
                'bag', 'wallet', 'jewelry', 'tie'
            ]
        }

        self.text_features_cache = {}
        self._cache_text_features()
        print("✓ OpenCLIP loaded with enhanced scene understanding")

    def _cache_text_features(self):
        """Pre-compute and cache text features"""
        with torch.no_grad():
            # Cache coarse labels
            prompts = [f"a photo of {label}" for label in self.coarse_labels]
            text = self.tokenizer(prompts)
            if torch.cuda.is_available():
                text = text.cuda()
            self.text_features_cache['coarse'] = self.model.encode_text(text)
            self.text_features_cache['coarse'] /= self.text_features_cache['coarse'].norm(dim=-1, keepdim=True)

            # Cache domain vocabularies
            for domain, labels in self.domain_vocabularies.items():
                prompts = [f"a photo of {label}" for label in labels]
                text = self.tokenizer(prompts)
                if torch.cuda.is_available():
                    text = text.cuda()
                features = self.model.encode_text(text)
                features /= features.norm(dim=-1, keepdim=True)
                self.text_features_cache[domain] = features

            # Cache scene vocabularies
            for scene_type, labels in self.scene_vocabularies.items():
                text = self.tokenizer(labels)
                if torch.cuda.is_available():
                    text = text.cuda()
                features = self.model.encode_text(text)
                features /= features.norm(dim=-1, keepdim=True)
                self.text_features_cache[f'scene_{scene_type}'] = features

    def analyze_scene(self, image: Image.Image) -> Dict:
        """Comprehensive scene analysis across urban type, lighting, and mood"""
        image_features = self.encode_image(image)
        scene_analysis = {}

        # Score the image against each cached scene vocabulary
        for scene_type in ['urban', 'lighting', 'mood']:
            cache_key = f'scene_{scene_type}'
            # Dividing by 0.01 scales cosine similarities by 100 (a fixed softmax temperature)
            similarity = (image_features @ self.text_features_cache[cache_key].T) / 0.01
            probs = similarity.softmax(dim=-1)
            results = {}
            for i, label in enumerate(self.scene_vocabularies[scene_type]):
                results[label] = float(probs[0, i].cpu())
            top_result = max(results.items(), key=lambda x: x[1])
            scene_analysis[scene_type] = {
                'top': top_result[0],
                'confidence': top_result[1],
                'all_scores': results
            }
        return scene_analysis

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image to feature vector"""
        with torch.no_grad():
            image_tensor = self.preprocess(image).unsqueeze(0)
            if torch.cuda.is_available():
                image_tensor = image_tensor.cuda()
            image_features = self.model.encode_image(image_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features

    def encode_text(self, text_list: List[str]) -> torch.Tensor:
        """Encode text list to feature vectors"""
        with torch.no_grad():
            prompts = [f"a photo of {text}" for text in text_list]
            text = self.tokenizer(prompts)
            if torch.cuda.is_available():
                text = text.cuda()
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features

    def classify_zero_shot(self, image: Image.Image, candidate_labels: List[str]) -> Dict[str, float]:
        """Zero-shot classification"""
        image_features = self.encode_image(image)
        text_features = self.encode_text(candidate_labels)
        similarity = (image_features @ text_features.T) / 0.01
        probs = similarity.softmax(dim=-1)
        results = {}
        for i, label in enumerate(candidate_labels):
            results[label] = float(probs[0, i].cpu())
        return results

    def classify_hierarchical(self, image: Image.Image) -> Dict:
        """Hierarchical classification: coarse category first, then fine-grained label"""
        image_features = self.encode_image(image)

        # Stage 1: pick the best coarse category
        coarse_similarity = (image_features @ self.text_features_cache['coarse'].T) / 0.01
        coarse_probs = coarse_similarity.softmax(dim=-1)
        coarse_results = {}
        for i, label in enumerate(self.coarse_labels):
            coarse_results[label] = float(coarse_probs[0, i].cpu())
        top_category = max(coarse_results, key=coarse_results.get)

        # Stage 2: refine within that category's domain vocabulary, if one exists
        if top_category in self.domain_vocabularies:
            fine_labels = self.domain_vocabularies[top_category]
            fine_similarity = (image_features @ self.text_features_cache[top_category].T) / 0.01
            fine_probs = fine_similarity.softmax(dim=-1)
            fine_results = {}
            for i, label in enumerate(fine_labels):
                fine_results[label] = float(fine_probs[0, i].cpu())
            top_prediction = max(fine_results, key=fine_results.get)
            return {
                'coarse': top_category,
                'fine': fine_results,
                'top_prediction': top_prediction,
                'confidence': fine_results[top_prediction]
            }

        # No fine-grained vocabulary: fall back to the coarse category
        return {
            'coarse': top_category,
            'top_prediction': top_category,
            'confidence': coarse_results[top_category]
        }

print("✓ OpenCLIPSemanticManager defined")