onurcopur committed
Commit f761027 · 1 Parent(s): cbeb83b

first commit

Files changed (9)
  1. .env.template +9 -0
  2. DEPLOYMENT.md +51 -0
  3. Dockerfile +38 -0
  4. README.md +56 -4
  5. app.py +30 -0
  6. embeddings.py +352 -0
  7. main.py +497 -0
  8. patch_attention.py +221 -0
  9. requirements.txt +30 -0
.env.template ADDED
@@ -0,0 +1,9 @@
+ # Hugging Face Token for VLM inference
+ # Required for image captioning using GLM-4.5V model
+ HF_TOKEN=your_huggingface_token_here
+
+ # Optional: Custom port (defaults to 7860 for HF Spaces)
+ # PORT=7860
+
+ # Optional: Model cache directory
+ # HF_HOME=/app/cache
DEPLOYMENT.md ADDED
@@ -0,0 +1,51 @@
+ # Hugging Face Spaces Deployment Guide
+
+ ## Environment Variables
+
+ ### Required Secrets
+ Set these in your Hugging Face Space settings:
+
+ 1. **HF_TOKEN**: Your Hugging Face access token
+    - Go to: https://huggingface.co/settings/tokens
+    - Create a new token with read access
+    - Add it as a secret in your Space settings
+
+ ## Hardware Requirements
+
+ - **Recommended**: T4 Small or higher for optimal performance
+ - **Minimum**: CPU (slower inference)
+ - **Memory**: At least 8GB RAM recommended
+ - **Storage**: 10GB+ for model caching
+
+ ## Deployment Steps
+
+ 1. Push all files to your HF Space repository:
+    ```bash
+    git add .
+    git commit -m "Deploy tattoo search engine"
+    git push
+    ```
+
+ 2. Set the HF_TOKEN secret in the Space settings
+
+ 3. The Space will build and deploy automatically
+
+ ## Testing
+
+ Once deployed, test these endpoints:
+
+ - `GET /health` - Health check
+ - `GET /models` - Available models
+ - `POST /search` - Upload an image and search
+
+ ## Troubleshooting
+
+ ### Common Issues
+
+ 1. **Missing HF_TOKEN**: Set the token in the Space secrets
+ 2. **Model loading errors**: Check the hardware requirements
+ 3. **Timeout errors**: Consider upgrading to GPU hardware
+ 4. **Memory errors**: Upgrade to a larger hardware tier
+
+ ### Logs
+ Check the Space logs for detailed error messages and startup information.
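As a quick smoke test for the Testing section in DEPLOYMENT.md, the read-only endpoints can be exercised from Python with `requests`. This is a minimal sketch; the Space URL is a placeholder you would replace with your own deployment.

```python
import requests

BASE_URL = "https://your-username-tattoo-search.hf.space"  # placeholder Space URL

# /health should return {"status": "healthy"} once the app is up
health = requests.get(f"{BASE_URL}/health", timeout=30)
print(health.status_code, health.json())

# /models lists the available embedding models and their default configurations
models = requests.get(f"{BASE_URL}/models", timeout=30)
print(models.json()["available_models"])
```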
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ # Use Python 3.12 as base image
+ FROM python:3.12-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     gcc \
+     g++ \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=7860
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir --upgrade pip
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create cache directory for models
+ RUN mkdir -p /app/cache
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Run the application
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,11 +1,63 @@
  ---
  title: Tattoo Search Engine
- emoji: 👁
- colorFrom: yellow
- colorTo: blue
+ emoji: 🎨
+ colorFrom: purple
+ colorTo: pink
  sdk: docker
  pinned: false
  license: mit
+ app_port: 7860
+ suggested_hardware: t4-small
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Tattoo Search Engine 🎨
+
+ An AI-powered tattoo search engine that finds similar tattoos based on visual similarity. Upload an image of a tattoo and discover visually similar designs from across the web.
+
+ ## Features
+
+ - **Multi-Model Support**: Choose from CLIP, DINOv2, or SigLIP embedding models
+ - **Advanced Search**: Combines image captioning with visual similarity search
+ - **Patch Attention Analysis**: Detailed analysis of which parts of tattoos are most similar
+ - **Real-time Processing**: Fast image processing and similarity computation
+ - **Multiple Platforms**: Searches across various tattoo platforms and image sources
+
+ ## API Endpoints
+
+ ### `POST /search`
+ Search for similar tattoos by uploading an image.
+
+ **Parameters:**
+ - `file`: Image file (required)
+ - `embedding_model`: Model to use - "clip", "dinov2", or "siglip" (default: "clip")
+ - `include_patch_attention`: Enable detailed patch analysis (default: false)
+
+ ### `POST /analyze-attention`
+ Analyze patch-level attention between two images.
+
+ **Parameters:**
+ - `query_file`: Query image file (required)
+ - `candidate_url`: URL of the candidate image to compare (required)
+ - `embedding_model`: Model to use (default: "clip")
+ - `include_visualizations`: Include attention visualizations (default: true)
+
+ ### `GET /models`
+ Get available embedding models and their configurations.
+
+ ### `GET /health`
+ Health check endpoint.
+
+ ## Models Used
+
+ - **Image Captioning**: GLM-4.5V via the HuggingFace Inference API
+ - **Visual Similarity**: CLIP ViT-B/32, DINOv2, or SigLIP
+ - **Search**: Multi-platform web search with intelligent filtering
+
+ ## Usage
+
+ 1. Upload a tattoo image
+ 2. Select your preferred embedding model
+ 3. Get ranked results with similarity scores
+ 4. Optionally analyze detailed patch-level similarities
+
+ Perfect for tattoo enthusiasts, artists, and anyone looking for tattoo inspiration!
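To illustrate the `POST /search` endpoint documented in the README, here is a minimal client sketch using `requests`. The host URL and image filename are placeholders; the response fields (`caption`, `results`, `score`, `url`) follow the handler in main.py.

```python
import requests

BASE_URL = "http://localhost:7860"  # placeholder; use your Space URL when deployed

# Upload a tattoo image and search with the DINOv2 embedding model
with open("my_tattoo.jpg", "rb") as f:  # placeholder image file
    response = requests.post(
        f"{BASE_URL}/search",
        files={"file": ("my_tattoo.jpg", f, "image/jpeg")},
        params={"embedding_model": "dinov2", "include_patch_attention": False},
        timeout=300,  # the first request may be slow while models load
    )

data = response.json()
print("Search query:", data["caption"])
for hit in data["results"][:5]:
    print(f"{hit['score']:.3f}  {hit['url']}")
```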
app.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import logging
+ from main import app
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+ if __name__ == "__main__":
+     # Get port from environment (Hugging Face Spaces uses 7860)
+     port = int(os.environ.get("PORT", 7860))
+
+     logger.info(f"Starting Tattoo Search Engine on port {port}")
+     logger.info("Available endpoints:")
+     logger.info("  POST /search - Search for similar tattoos")
+     logger.info("  POST /analyze-attention - Analyze patch-level attention")
+     logger.info("  GET /models - Get available embedding models")
+     logger.info("  GET /health - Health check")
+
+     import uvicorn
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=port,
+         log_level="info",
+         access_log=True
+     )
embeddings.py ADDED
@@ -0,0 +1,352 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingModel(ABC):
+     """Abstract base class for embedding models."""
+
+     def __init__(self, device: torch.device):
+         self.device = device
+         self.model = None
+         self.preprocess = None
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load the embedding model and preprocessing."""
+         pass
+
+     @abstractmethod
+     def encode_image(self, image: Image.Image) -> torch.Tensor:
+         """Encode an image into a feature vector."""
+         pass
+
+     def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+         """Encode an image into patch-level features. Override in subclasses that support it."""
+         raise NotImplementedError("Patch-level encoding not implemented for this model")
+
+     def compute_patch_attention(self, query_patches: torch.Tensor, candidate_patches: torch.Tensor) -> torch.Tensor:
+         """Compute attention weights between query and candidate patches."""
+         # query_patches: [num_query_patches, feature_dim]
+         # candidate_patches: [num_candidate_patches, feature_dim]
+
+         # Normalize patches
+         query_patches = F.normalize(query_patches, p=2, dim=1)
+         candidate_patches = F.normalize(candidate_patches, p=2, dim=1)
+
+         # Compute attention matrix: [num_query_patches, num_candidate_patches]
+         attention_matrix = torch.mm(query_patches, candidate_patches.T)
+
+         return attention_matrix
+
+     @abstractmethod
+     def get_model_name(self) -> str:
+         """Return the model name."""
+         pass
+
+     def compute_similarity(self, query_features: torch.Tensor, candidate_features: torch.Tensor) -> float:
+         """Compute similarity between query and candidate features."""
+         return torch.mm(query_features, candidate_features.T).item()
+
+
+ class CLIPEmbedding(EmbeddingModel):
+     """CLIP-based embedding model."""
+
+     def __init__(self, device: torch.device, model_name: str = "ViT-B-32"):
+         super().__init__(device)
+         self.model_name = model_name
+         self.tokenizer = None
+         self.load_model()
+
+     def load_model(self) -> None:
+         """Load CLIP model and preprocessing."""
+         try:
+             import open_clip
+             logger.info(f"Loading CLIP model: {self.model_name}")
+
+             self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+                 self.model_name, pretrained="openai"
+             )
+             self.model.to(self.device)
+             self.tokenizer = open_clip.get_tokenizer(self.model_name)
+
+             logger.info(f"CLIP model {self.model_name} loaded successfully")
+         except Exception as e:
+             logger.error(f"Failed to load CLIP model: {e}")
+             raise
+
+     def encode_image(self, image: Image.Image) -> torch.Tensor:
+         """Encode image using CLIP."""
+         try:
+             image_input = self.preprocess(image).unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 features = self.model.encode_image(image_input)
+                 features = F.normalize(features, p=2, dim=1)
+
+             return features
+         except Exception as e:
+             logger.error(f"Failed to encode image with CLIP: {e}")
+             raise
+
+     def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+         """Encode image patches using CLIP vision transformer."""
+         try:
+             image_input = self.preprocess(image).unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 # Get patch features from CLIP vision transformer
+                 vision_model = self.model.visual
+
+                 # Pass through patch embedding and positional encoding
+                 x = vision_model.conv1(image_input)  # shape = [*, width, grid, grid]
+                 x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+                 x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+                 # Add class token and positional embeddings
+                 x = torch.cat([vision_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
+                 x = x + vision_model.positional_embedding.to(x.dtype)
+
+                 # Apply layer norm
+                 x = vision_model.ln_pre(x)
+
+                 x = x.permute(1, 0, 2)  # NLD -> LND
+
+                 # Pass through transformer blocks
+                 for block in vision_model.transformer.resblocks:
+                     x = block(x)
+
+                 x = x.permute(1, 0, 2)  # LND -> NLD
+
+                 # Remove class token to get only patch features
+                 patch_features = x[:, 1:, :]  # [1, num_patches, feature_dim]
+                 patch_features = vision_model.ln_post(patch_features)
+
+                 # Apply projection if it exists
+                 if vision_model.proj is not None:
+                     patch_features = patch_features @ vision_model.proj
+
+                 # Normalize patch features
+                 patch_features = F.normalize(patch_features, p=2, dim=-1)
+
+                 return patch_features.squeeze(0)  # [num_patches, feature_dim]
+
+         except Exception as e:
+             logger.error(f"Failed to encode image patches with CLIP: {e}")
+             raise
+
+     def get_model_name(self) -> str:
+         return f"CLIP-{self.model_name}"
+
+
+ class DINOv2Embedding(EmbeddingModel):
+     """DINOv2-based embedding model."""
+
+     def __init__(self, device: torch.device, model_name: str = "dinov2_vitb14"):
+         super().__init__(device)
+         self.model_name = model_name
+         self.load_model()
+
+     def load_model(self) -> None:
+         """Load DINOv2 model and preprocessing."""
+         try:
+             import torch.hub
+             from torchvision import transforms
+
+             logger.info(f"Loading DINOv2 model: {self.model_name}")
+
+             # Load DINOv2 model from torch hub
+             self.model = torch.hub.load('facebookresearch/dinov2', self.model_name)
+             self.model.to(self.device)
+             self.model.eval()
+
+             # DINOv2 preprocessing
+             self.preprocess = transforms.Compose([
+                 transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
+                 transforms.CenterCrop(224),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ])
+
+             logger.info(f"DINOv2 model {self.model_name} loaded successfully")
+         except Exception as e:
+             logger.error(f"Failed to load DINOv2 model: {e}")
+             raise
+
+     def encode_image(self, image: Image.Image) -> torch.Tensor:
+         """Encode image using DINOv2."""
+         try:
+             image_input = self.preprocess(image).unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 features = self.model(image_input)
+                 features = F.normalize(features, p=2, dim=1)
+
+             return features
+         except Exception as e:
+             logger.error(f"Failed to encode image with DINOv2: {e}")
+             raise
+
+     def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+         """Encode image patches using DINOv2."""
+         try:
+             image_input = self.preprocess(image).unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 # Get patch features from DINOv2
+                 # DINOv2 forward_features returns a dict with 'x_norm_patchtokens' containing patch features
+                 features_dict = self.model.forward_features(image_input)
+                 patch_features = features_dict['x_norm_patchtokens']  # [1, num_patches, feature_dim]
+
+                 # Normalize patch features
+                 patch_features = F.normalize(patch_features, p=2, dim=-1)
+
+                 return patch_features.squeeze(0)  # [num_patches, feature_dim]
+
+         except Exception as e:
+             logger.error(f"Failed to encode image patches with DINOv2: {e}")
+             raise
+
+     def get_model_name(self) -> str:
+         return f"DINOv2-{self.model_name}"
+
+
+ class SigLIPEmbedding(EmbeddingModel):
+     """SigLIP-based embedding model."""
+
+     def __init__(self, device: torch.device, model_name: str = "google/siglip-base-patch16-224"):
+         super().__init__(device)
+         self.model_name = model_name
+         self.processor = None
+         self.load_model()
+
+     def load_model(self) -> None:
+         """Load SigLIP model and preprocessing."""
+         try:
+             # Check for required dependencies
+             try:
+                 import sentencepiece
+             except ImportError:
+                 raise ImportError(
+                     "SentencePiece is required for SigLIP. Install with: pip install sentencepiece"
+                 )
+
+             from transformers import SiglipVisionModel, SiglipProcessor
+
+             logger.info(f"Loading SigLIP model: {self.model_name}")
+
+             self.model = SiglipVisionModel.from_pretrained(self.model_name)
+             self.model.to(self.device)
+             self.model.eval()
+
+             self.processor = SiglipProcessor.from_pretrained(self.model_name)
+
+             logger.info(f"SigLIP model {self.model_name} loaded successfully")
+         except Exception as e:
+             logger.error(f"Failed to load SigLIP model: {e}")
+             raise
+
+     def encode_image(self, image: Image.Image) -> torch.Tensor:
+         """Encode image using SigLIP."""
+         try:
+             inputs = self.processor(images=image, return_tensors="pt")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 features = outputs.last_hidden_state.mean(dim=1)  # Global average pooling
+                 features = F.normalize(features, p=2, dim=1)
+
+             return features
+         except Exception as e:
+             logger.error(f"Failed to encode image with SigLIP: {e}")
+             raise
+
+     def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+         """Encode image patches using SigLIP."""
+         try:
+             inputs = self.processor(images=image, return_tensors="pt")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 # last_hidden_state contains patch features: [1, num_patches, feature_dim]
+                 patch_features = outputs.last_hidden_state
+
+                 # Normalize patch features
+                 patch_features = F.normalize(patch_features, p=2, dim=-1)
+
+                 return patch_features.squeeze(0)  # [num_patches, feature_dim]
+
+         except Exception as e:
+             logger.error(f"Failed to encode image patches with SigLIP: {e}")
+             raise
+
+     def get_model_name(self) -> str:
+         return f"SigLIP-{self.model_name.split('/')[-1]}"
+
+
+ class EmbeddingModelFactory:
+     """Factory class for creating embedding models."""
+
+     AVAILABLE_MODELS = {
+         "clip": CLIPEmbedding,
+         "dinov2": DINOv2Embedding,
+         "siglip": SigLIPEmbedding,
+     }
+
+     @classmethod
+     def create_model(cls, model_type: str, device: torch.device, **kwargs) -> EmbeddingModel:
+         """Create an embedding model instance.
+
+         Args:
+             model_type: Type of model ('clip', 'dinov2', 'siglip')
+             device: PyTorch device
+             **kwargs: Additional arguments for specific models
+
+         Returns:
+             EmbeddingModel instance
+         """
+         if model_type.lower() not in cls.AVAILABLE_MODELS:
+             raise ValueError(f"Unknown model type: {model_type}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
+
+         model_class = cls.AVAILABLE_MODELS[model_type.lower()]
+
+         try:
+             return model_class(device, **kwargs)
+         except Exception as e:
+             logger.error(f"Failed to create {model_type} model: {e}")
+             # Fall back to CLIP if the requested model fails
+             if model_type.lower() != 'clip':
+                 logger.info("Falling back to CLIP model")
+                 return cls.AVAILABLE_MODELS['clip'](device, **kwargs)
+             else:
+                 raise
+
+     @classmethod
+     def get_available_models(cls) -> List[str]:
+         """Get list of available model types."""
+         return list(cls.AVAILABLE_MODELS.keys())
+
+
+ def get_default_model_configs() -> Dict[str, Dict[str, Any]]:
+     """Get default configurations for each model type."""
+     return {
+         "clip": {
+             "model_name": "ViT-B-32",
+             "description": "OpenAI CLIP model - good general purpose vision-language model"
+         },
+         "dinov2": {
+             "model_name": "dinov2_vitb14",
+             "description": "Meta DINOv2 - self-supervised vision transformer, good for visual features"
+         },
+         "siglip": {
+             "model_name": "google/siglip-base-patch16-224",
+             "description": "Google SigLIP - improved CLIP-like model with better training"
+         }
+     }
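As a usage sketch for the factory and patch-attention helpers defined in embeddings.py, the following shows how the pieces fit together outside the API. The image paths are placeholders.

```python
import torch
from PIL import Image
from embeddings import EmbeddingModelFactory

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a model through the factory; it falls back to CLIP if the requested type fails
model = EmbeddingModelFactory.create_model("clip", device)

query = Image.open("query_tattoo.jpg").convert("RGB")          # placeholder path
candidate = Image.open("candidate_tattoo.jpg").convert("RGB")  # placeholder path

# Global similarity from the pooled, L2-normalized embeddings
q_feat = model.encode_image(query)
c_feat = model.encode_image(candidate)
print("cosine similarity:", model.compute_similarity(q_feat, c_feat))

# Patch-level attention matrix: [num_query_patches, num_candidate_patches]
q_patches = model.encode_image_patches(query)
c_patches = model.encode_image_patches(candidate)
attn = model.compute_patch_attention(q_patches, c_patches)
print("attention matrix shape:", tuple(attn.shape))
```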
main.py ADDED
@@ -0,0 +1,497 @@
+ import io
+ import json
+ import logging
+ import os
+ import random
+ import re
+ import time
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import Any, Dict, List, Optional
+
+ import requests
+ import torch
+ import torch.nn.functional as F
+ from dotenv import load_dotenv
+ from fastapi import FastAPI, File, HTTPException, UploadFile, Query
+ from fastapi.middleware.cors import CORSMiddleware
+ from huggingface_hub import InferenceClient
+ from PIL import Image
+ from search_engines import SearchEngineManager
+ from utils import SearchCache, URLValidator
+ from embeddings import EmbeddingModelFactory, EmbeddingModel, get_default_model_configs
+ from patch_attention import PatchAttentionAnalyzer
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Configuration
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ if not HF_TOKEN:
+     raise ValueError("HF_TOKEN environment variable is required")
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="Tattoo Search Engine", version="1.0.0")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ class TattooSearchEngine:
+     def __init__(self, embedding_model_type: str = "clip"):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+         # Initialize HuggingFace InferenceClient for VLM captioning
+         logger.info("Initializing HuggingFace InferenceClient...")
+         self.client = InferenceClient(
+             provider="novita",
+             api_key=HF_TOKEN,
+         )
+         self.vlm_model = "zai-org/GLM-4.5V"
+         logger.info(f"Using VLM model: {self.vlm_model}")
+
+         # Load embedding model
+         logger.info(f"Loading embedding model: {embedding_model_type}")
+         self.embedding_model_type = embedding_model_type.lower()
+         self.embedding_model = EmbeddingModelFactory.create_model(
+             embedding_model_type, self.device
+         )
+         logger.info(f"Using embedding model: {self.embedding_model.get_model_name()}")
+
+         # Initialize new search system
+         logger.info("Initializing search system...")
+         self.search_manager = SearchEngineManager(max_workers=5)
+         self.url_validator = URLValidator(max_workers=10, timeout=10)
+         self.search_cache = SearchCache(default_ttl=3600, max_size=1000)
+
+         # Setup enhanced web scraping
+         self.user_agents = [
+             "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+             "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+             "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
+             "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15",
+             "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+         ]
+
+         logger.info("Search system initialized successfully!")
+
+     def generate_caption(self, image: Image.Image) -> str:
+         """Generate a tattoo search query using the HuggingFace InferenceClient."""
+         try:
+             # Convert PIL image to base64 URL format
+             img_buffer = io.BytesIO()
+             image.save(img_buffer, format="JPEG", quality=95)
+             img_buffer.seek(0)
+
+             # Create image URL for the API
+             import base64
+
+             image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
+             image_url = f"data:image/jpeg;base64,{image_b64}"
+
+             completion = self.client.chat.completions.create(
+                 model=self.vlm_model,
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": [
+                             {
+                                 "type": "text",
+                                 "text": "Generate one search engine query to find tattoos most similar to this image. Respond in JSON format with a single key \"search_query\".",
+                             },
+                             {
+                                 "type": "image_url",
+                                 "image_url": {"url": image_url},
+                             },
+                         ],
+                     }
+                 ],
+             )
+             caption = completion.choices[0].message.content
+             if caption:
+                 match = re.search(r"\{.*\}", caption)
+                 if match:
+                     data = json.loads(match.group())
+                     search_query = data["search_query"]
+                     return search_query
+                 logger.warning("Could not parse a search query from the VLM response")
+                 return "tattoo artwork"
+             else:
+                 logger.warning("No caption generated from VLM")
+                 return "tattoo artwork"
+
+         except Exception as e:
+             logger.error(f"Failed to generate caption: {e}")
+             return "tattoo artwork"
+
+     def search_images(self, query: str, max_results: int = 50) -> List[str]:
+         """Search for tattoo images across multiple platforms with caching and validation."""
+         # Check cache first
+         cache_key = SearchCache.create_cache_key(query, max_results)
+         cached_result = self.search_cache.get(cache_key)
+         if cached_result:
+             logger.info(f"Cache hit for query: {query}")
+             return cached_result
+
+         logger.info(f"Searching for images: {query}")
+
+         # Use new search system with fallback
+         search_result = self.search_manager.search_with_fallback(
+             query=query, max_results=max_results, min_results_threshold=10
+         )
+
+         # Extract URLs from search results
+         urls = [image.url for image in search_result.images]
+
+         if not urls:
+             logger.warning(f"No URLs found for query: {query}")
+             return []
+
+         # Validate URLs
+         logger.info(f"Validating {len(urls)} URLs...")
+         valid_urls = self.url_validator.validate_urls(urls)
+
+         if not valid_urls:
+             logger.warning(f"No valid URLs found for query: {query}")
+             return []
+
+         # Cache the result
+         self.search_cache.set(cache_key, valid_urls, ttl=3600)
+
+         logger.info(
+             f"Search completed: {len(valid_urls)} valid URLs from "
+             f"{len(search_result.platforms_used)} platforms in "
+             f"{search_result.search_duration:.2f}s"
+         )
+
+         return valid_urls[:max_results]
+
+     def download_image(self, url: str, max_retries: int = 3) -> Optional[Image.Image]:
+         for attempt in range(max_retries):
+             try:
+                 # Instagram-optimized headers
+                 headers = {
+                     "User-Agent": random.choice(self.user_agents),
+                     "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
+                     "Accept-Language": "en-US,en;q=0.9",
+                     "Accept-Encoding": "gzip, deflate, br",
+                     "DNT": "1",
+                     "Connection": "keep-alive",
+                     "Upgrade-Insecure-Requests": "1",
+                     "Sec-Fetch-Dest": "image",
+                     "Sec-Fetch-Mode": "no-cors",
+                     "Sec-Fetch-Site": "cross-site",
+                     "Cache-Control": "no-cache",
+                     "Pragma": "no-cache",
+                 }
+
+                 # Pinterest-specific headers
+                 if "pinterest" in url.lower() or "pinimg" in url.lower():
+                     headers.update(
+                         {
+                             "Referer": "https://www.pinterest.com/",
+                             "Origin": "https://www.pinterest.com",
+                             "X-Requested-With": "XMLHttpRequest",
+                             "Sec-Fetch-User": "?1",
+                             "X-Pinterest-Source": "web",
+                             "X-APP-VERSION": "web",
+                         }
+                     )
+                 else:
+                     headers["Referer"] = "https://www.google.com/"
+
+                 response = requests.get(
+                     url, headers=headers, timeout=15, allow_redirects=True, stream=True
+                 )
+                 response.raise_for_status()
+
+                 # Validate content type
+                 content_type = response.headers.get("content-type", "").lower()
+                 if not content_type.startswith("image/"):
+                     logger.warning(f"Invalid content type for {url}: {content_type}")
+                     return None
+
+                 # Check file size (avoid downloading huge files)
+                 content_length = response.headers.get("content-length")
+                 if (
+                     content_length and int(content_length) > 10 * 1024 * 1024
+                 ):  # 10MB limit
+                     logger.warning(f"Image too large: {url} ({content_length} bytes)")
+                     return None
+
+                 # Download and process image
+                 image_data = response.content
+                 if len(image_data) < 1024:  # Skip very small images (likely broken)
+                     logger.warning(f"Image too small: {url} ({len(image_data)} bytes)")
+                     return None
+
+                 image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+                 # Validate image dimensions
+                 if image.size[0] < 50 or image.size[1] < 50:
+                     logger.warning(f"Image dimensions too small: {url} {image.size}")
+                     return None
+
+                 return image
+
+             except requests.exceptions.RequestException as e:
+                 if attempt < max_retries - 1:
+                     wait_time = (2**attempt) + random.uniform(0, 1)
+                     logger.info(f"Retry {attempt + 1} for {url} in {wait_time:.1f}s")
+                     time.sleep(wait_time)
+                 else:
+                     logger.warning(
+                         f"Failed to download image {url} after {max_retries} attempts: {e}"
+                     )
+             except Exception as e:
+                 logger.warning(f"Failed to process image {url}: {e}")
+                 break
+
+         return None
+
+     def download_and_process_image(
+         self, url: str, query_features: torch.Tensor, query_image: Image.Image = None,
+         include_patch_attention: bool = False
+     ) -> Optional[Dict[str, Any]]:
+         """Download and compute similarity for a single image."""
+         candidate_image = self.download_image(url)
+         if candidate_image is None:
+             return None
+
+         try:
+             candidate_features = self.embedding_model.encode_image(candidate_image)
+             similarity = self.embedding_model.compute_similarity(query_features, candidate_features)
+
+             result = {"score": float(similarity), "url": url}
+
+             # Add patch attention analysis if requested
+             if include_patch_attention and query_image is not None:
+                 try:
+                     analyzer = PatchAttentionAnalyzer(self.embedding_model)
+                     patch_data = analyzer.compute_patch_similarities(query_image, candidate_image)
+                     result["patch_attention"] = {
+                         "overall_similarity": patch_data["overall_similarity"],
+                         "query_grid_size": patch_data["query_grid_size"],
+                         "candidate_grid_size": patch_data["candidate_grid_size"],
+                         "attention_summary": analyzer.get_similarity_summary(patch_data)
+                     }
+                 except Exception as e:
+                     logger.warning(f"Failed to compute patch attention for {url}: {e}")
+                     result["patch_attention"] = None
+
+             return result
+
+         except Exception as e:
+             logger.warning(f"Error processing candidate image {url}: {e}")
+             return None
+
+     def compute_similarity(
+         self, query_image: Image.Image, candidate_urls: List[str], include_patch_attention: bool = False
+     ) -> List[Dict[str, Any]]:
+         # Encode query image using the selected embedding model
+         query_features = self.embedding_model.encode_image(query_image)
+
+         results = []
+
+         # Use ThreadPoolExecutor for concurrent downloading and processing
+         max_workers = min(10, len(candidate_urls))  # Limit concurrent downloads
+
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             # Submit all download tasks
+             future_to_url = {
+                 executor.submit(
+                     self.download_and_process_image, url, query_features, query_image, include_patch_attention
+                 ): url
+                 for url in candidate_urls
+             }
+
+             # Process completed downloads with rate limiting
+             for future in as_completed(future_to_url):
+                 url = future_to_url[future]
+                 try:
+                     result = future.result()
+                     if result is not None:
+                         results.append(result)
+
+                         # Stop early if we have enough good results (unless patch attention is needed)
+                         target_count = 5 if include_patch_attention else 20
+                         if len(results) >= target_count:
+                             # Cancel remaining futures
+                             for remaining_future in future_to_url:
+                                 remaining_future.cancel()
+                             break
+
+                 except Exception as e:
+                     logger.warning(f"Error in concurrent processing for {url}: {e}")
+
+                 # Small delay to be respectful to servers
+                 time.sleep(0.1)
+
+         # Sort by similarity score (highest first)
+         results.sort(key=lambda x: x["score"], reverse=True)
+
+         final_count = 3 if include_patch_attention else 15
+         return results[:final_count]
+
+
+ # Global variable to store the search engine instance
+ search_engine = None
+
+ def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
+     """Get or create a search engine instance with the specified embedding model."""
+     global search_engine
+     # Recreate the engine only when a different embedding model type is requested
+     if search_engine is None or search_engine.embedding_model_type != embedding_model.lower():
+         search_engine = TattooSearchEngine(embedding_model)
+     return search_engine
+
+
+ @app.post("/search")
+ async def search_tattoos(
+     file: UploadFile = File(...),
+     embedding_model: str = Query(default="clip", description="Embedding model to use (clip, dinov2, siglip)"),
+     include_patch_attention: bool = Query(default=False, description="Include patch-level attention analysis")
+ ):
+     if not file.content_type or not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="File must be an image")
+
+     try:
+         # Validate embedding model
+         available_models = EmbeddingModelFactory.get_available_models()
+         if embedding_model not in available_models:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid embedding model. Available: {available_models}"
+             )
+
+         # Get search engine with specified embedding model
+         engine = get_search_engine(embedding_model)
+
+         # Read and process the uploaded image
+         image_data = await file.read()
+         query_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+         # Generate caption
+         logger.info("Generating caption...")
+         caption = engine.generate_caption(query_image)
+         logger.info(f"Generated caption: {caption}")
+
+         # Search for candidate images
+         logger.info("Searching for candidate images...")
+         candidate_urls = engine.search_images(caption, max_results=100)
+
+         if not candidate_urls:
+             return {"caption": caption, "results": [], "embedding_model": engine.embedding_model.get_model_name()}
+
+         # Compute similarities and rank
+         logger.info("Computing similarities...")
+         results = engine.compute_similarity(query_image, candidate_urls, include_patch_attention)
+
+         return {
+             "caption": caption,
+             "results": results,
+             "embedding_model": engine.embedding_model.get_model_name(),
+             "patch_attention_enabled": include_patch_attention
+         }
+
+     except HTTPException:
+         # Re-raise client errors (e.g. invalid embedding model) without converting them to 500s
+         raise
+     except Exception as e:
+         logger.error(f"Error processing request: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/analyze-attention")
+ async def analyze_patch_attention(
+     query_file: UploadFile = File(...),
+     candidate_url: str = Query(..., description="URL of the candidate image to compare"),
+     embedding_model: str = Query(default="clip", description="Embedding model to use (clip, dinov2, siglip)"),
+     include_visualizations: bool = Query(default=True, description="Include attention visualizations")
+ ):
+     """Analyze patch-level attention between query image and a specific candidate image."""
+     if not query_file.content_type or not query_file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="Query file must be an image")
+
+     try:
+         # Validate embedding model
+         available_models = EmbeddingModelFactory.get_available_models()
+         if embedding_model not in available_models:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid embedding model. Available: {available_models}"
+             )
+
+         # Get search engine with specified embedding model
+         engine = get_search_engine(embedding_model)
+
+         # Read query image
+         query_image_data = await query_file.read()
+         query_image = Image.open(io.BytesIO(query_image_data)).convert("RGB")
+
+         # Download candidate image
+         candidate_image = engine.download_image(candidate_url)
+         if candidate_image is None:
+             raise HTTPException(status_code=400, detail="Failed to download candidate image")
+
+         # Analyze patch attention
+         analyzer = PatchAttentionAnalyzer(engine.embedding_model)
+         similarity_data = analyzer.compute_patch_similarities(query_image, candidate_image)
+
+         result = {
+             "query_image_size": query_image.size,
+             "candidate_image_size": candidate_image.size,
+             "candidate_url": candidate_url,
+             "embedding_model": engine.embedding_model.get_model_name(),
+             "similarity_analysis": analyzer.get_similarity_summary(similarity_data),
+             "attention_matrix_shape": similarity_data['attention_matrix'].shape,
+             "top_correspondences": similarity_data['top_correspondences'][:10]  # Top 10
+         }
+
+         # Add visualizations if requested
+         if include_visualizations:
+             try:
+                 attention_heatmap = analyzer.visualize_attention_heatmap(
+                     query_image, candidate_image, similarity_data
+                 )
+                 top_correspondences_viz = analyzer.visualize_top_correspondences(
+                     query_image, candidate_image, similarity_data
+                 )
+
+                 result["visualizations"] = {
+                     "attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
+                     "top_correspondences": f"data:image/png;base64,{top_correspondences_viz}"
+                 }
+             except Exception as e:
+                 logger.warning(f"Failed to generate visualizations: {e}")
+                 result["visualizations"] = None
+
+         return result
+
+     except HTTPException:
+         # Re-raise client errors without converting them to 500s
+         raise
+     except Exception as e:
+         logger.error(f"Error analyzing patch attention: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/models")
+ async def get_available_models():
+     """Get list of available embedding models and their configurations."""
+     models = EmbeddingModelFactory.get_available_models()
+     configs = get_default_model_configs()
+     return {
+         "available_models": models,
+         "model_configs": configs
+     }
+
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=8000)
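To complement the `/search` flow, the `/analyze-attention` endpoint defined above can be exercised like this. The URLs and filenames are placeholders; the visualizations in the response are base64-encoded PNG data URIs, as produced by PatchAttentionAnalyzer.

```python
import base64
import requests

BASE_URL = "http://localhost:7860"  # placeholder

with open("query_tattoo.jpg", "rb") as f:  # placeholder image
    resp = requests.post(
        f"{BASE_URL}/analyze-attention",
        files={"query_file": ("query_tattoo.jpg", f, "image/jpeg")},
        params={
            "candidate_url": "https://example.com/some_tattoo.jpg",  # placeholder candidate
            "embedding_model": "clip",
            "include_visualizations": True,
        },
        timeout=300,
    )

data = resp.json()
print("overall similarity:", data["similarity_analysis"]["overall_similarity"])

# Save the attention heatmap if visualizations were generated
viz = data.get("visualizations")
if viz:
    png_b64 = viz["attention_heatmap"].split(",", 1)[1]  # strip the data URI prefix
    with open("attention_heatmap.png", "wb") as out:
        out.write(base64.b64decode(png_b64))
```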
patch_attention.py ADDED
@@ -0,0 +1,221 @@
+ import numpy as np
+ import torch
+ import matplotlib
+ matplotlib.use('Agg')  # Use non-interactive backend for server environments
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from typing import Tuple, Dict, Any
+ import io
+ import base64
+ import math
+
+
+ class PatchAttentionAnalyzer:
+     """Utility class for computing and visualizing patch-level attention between images."""
+
+     def __init__(self, embedding_model):
+         self.embedding_model = embedding_model
+
+     def compute_patch_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
+         """
+         Compute patch-level similarities between query and candidate images.
+
+         Returns:
+             Dictionary containing attention matrix, top correspondences, and metadata
+         """
+         try:
+             # Get patch features for both images
+             query_patches = self.embedding_model.encode_image_patches(query_image)
+             candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
+
+             # Compute attention matrix
+             attention_matrix = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
+
+             # Get grid dimensions (assuming square patch grids for ViT models)
+             query_grid_size = int(math.sqrt(query_patches.shape[0]))
+             candidate_grid_size = int(math.sqrt(candidate_patches.shape[0]))
+
+             # Find top correspondences for each query patch
+             top_correspondences = []
+             for i in range(attention_matrix.shape[0]):
+                 patch_similarities = attention_matrix[i]
+                 top_indices = torch.topk(patch_similarities, k=min(5, patch_similarities.shape[0]))
+
+                 top_correspondences.append({
+                     'query_patch_idx': i,
+                     'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
+                     'top_candidate_indices': top_indices.indices.tolist(),
+                     'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
+                                              for idx in top_indices.indices],
+                     'similarity_scores': top_indices.values.tolist()
+                 })
+
+             return {
+                 'attention_matrix': attention_matrix.cpu().numpy(),
+                 'query_grid_size': query_grid_size,
+                 'candidate_grid_size': candidate_grid_size,
+                 'top_correspondences': top_correspondences,
+                 'query_patches_shape': query_patches.shape,
+                 'candidate_patches_shape': candidate_patches.shape,
+                 'overall_similarity': torch.mean(attention_matrix).item()
+             }
+
+         except NotImplementedError:
+             raise ValueError(f"Patch-level encoding not supported for {self.embedding_model.get_model_name()}")
+         except Exception as e:
+             raise RuntimeError(f"Error computing patch similarities: {e}")
+
+     def _patch_idx_to_coord(self, patch_idx: int, grid_size: int) -> Tuple[int, int]:
+         """Convert flat patch index to (row, col) coordinate."""
+         row = patch_idx // grid_size
+         col = patch_idx % grid_size
+         return (row, col)
+
+     def visualize_attention_heatmap(self, query_image: Image.Image, candidate_image: Image.Image,
+                                     similarity_data: Dict[str, Any], figsize: Tuple[int, int] = (15, 10)) -> str:
+         """
+         Create a visualization showing the attention heatmap between patches.
+         Returns a base64-encoded PNG image.
+         """
+         attention_matrix = similarity_data['attention_matrix']
+         query_grid_size = similarity_data['query_grid_size']
+         candidate_grid_size = similarity_data['candidate_grid_size']
+
+         fig, axes = plt.subplots(2, 2, figsize=figsize)
+         fig.suptitle(f'Patch Attention Analysis - Overall Similarity: {similarity_data["overall_similarity"]:.3f}',
+                      fontsize=14, fontweight='bold')
+
+         # Plot original images
+         axes[0, 0].imshow(query_image)
+         axes[0, 0].set_title('Query Image')
+         axes[0, 0].axis('off')
+         self._overlay_patch_grid(axes[0, 0], query_image.size, query_grid_size)
+
+         axes[0, 1].imshow(candidate_image)
+         axes[0, 1].set_title('Candidate Image')
+         axes[0, 1].axis('off')
+         self._overlay_patch_grid(axes[0, 1], candidate_image.size, candidate_grid_size)
+
+         # Plot attention matrix
+         im = axes[1, 0].imshow(attention_matrix, cmap='viridis', aspect='auto')
+         axes[1, 0].set_title('Attention Matrix')
+         axes[1, 0].set_xlabel('Candidate Patches')
+         axes[1, 0].set_ylabel('Query Patches')
+         plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04)
+
+         # Plot attention summary (max attention per query patch)
+         max_attention_per_query = np.max(attention_matrix, axis=1)
+         attention_grid = max_attention_per_query.reshape(query_grid_size, query_grid_size)
+
+         im2 = axes[1, 1].imshow(attention_grid, cmap='hot', interpolation='nearest')
+         axes[1, 1].set_title('Max Attention per Query Patch')
+         axes[1, 1].set_xlabel('Patch Column')
+         axes[1, 1].set_ylabel('Patch Row')
+         plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04)
+
+         plt.tight_layout()
+
+         # Convert to base64
+         buffer = io.BytesIO()
+         plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
+         buffer.seek(0)
+         plot_data = buffer.getvalue()
+         buffer.close()
+         plt.close()
+
+         return base64.b64encode(plot_data).decode()
+
+     def visualize_top_correspondences(self, query_image: Image.Image, candidate_image: Image.Image,
+                                       similarity_data: Dict[str, Any], num_top_patches: int = 6) -> str:
+         """
+         Visualize the top corresponding patches between query and candidate images.
+         Returns a base64-encoded PNG image.
+         """
+         top_correspondences = similarity_data['top_correspondences']
+         query_grid_size = similarity_data['query_grid_size']
+         candidate_grid_size = similarity_data['candidate_grid_size']
+
+         # Sort by best similarity score
+         sorted_correspondences = sorted(
+             top_correspondences,
+             key=lambda x: max(x['similarity_scores']),
+             reverse=True
+         )[:num_top_patches]
+
+         fig, axes = plt.subplots(2, num_top_patches, figsize=(3*num_top_patches, 6))
+         fig.suptitle('Top Patch Correspondences', fontsize=14, fontweight='bold')
+
+         for i, correspondence in enumerate(sorted_correspondences):
+             query_coord = correspondence['query_patch_coord']
+             best_candidate_coord = correspondence['top_candidate_coords'][0]
+             best_score = correspondence['similarity_scores'][0]
+
+             # Extract and show query patch
+             query_patch = self._extract_patch_from_image(query_image, query_coord, query_grid_size)
+             axes[0, i].imshow(query_patch)
+             axes[0, i].set_title(f'Q-Patch {query_coord}\nScore: {best_score:.3f}')
+             axes[0, i].axis('off')
+
+             # Extract and show best matching candidate patch
+             candidate_patch = self._extract_patch_from_image(candidate_image, best_candidate_coord, candidate_grid_size)
+             axes[1, i].imshow(candidate_patch)
+             axes[1, i].set_title(f'C-Patch {best_candidate_coord}')
+             axes[1, i].axis('off')
+
+         plt.tight_layout()
+
+         # Convert to base64
+         buffer = io.BytesIO()
+         plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
+         buffer.seek(0)
+         plot_data = buffer.getvalue()
+         buffer.close()
+         plt.close()
+
+         return base64.b64encode(plot_data).decode()
+
+     def _overlay_patch_grid(self, ax, image_size: Tuple[int, int], grid_size: int):
+         """Overlay patch grid lines on an image."""
+         width, height = image_size
+         patch_width = width / grid_size
+         patch_height = height / grid_size
+
+         # Draw vertical lines
+         for i in range(1, grid_size):
+             x = i * patch_width
+             ax.axvline(x=x, color='white', alpha=0.5, linewidth=1)
+
+         # Draw horizontal lines
+         for i in range(1, grid_size):
+             y = i * patch_height
+             ax.axhline(y=y, color='white', alpha=0.5, linewidth=1)
+
+     def _extract_patch_from_image(self, image: Image.Image, patch_coord: Tuple[int, int], grid_size: int) -> Image.Image:
+         """Extract a specific patch from an image based on grid coordinates."""
+         row, col = patch_coord
+         width, height = image.size
+
+         patch_width = width // grid_size
+         patch_height = height // grid_size
+
+         left = col * patch_width
+         top = row * patch_height
+         right = min((col + 1) * patch_width, width)
+         bottom = min((row + 1) * patch_height, height)
+
+         return image.crop((left, top, right, bottom))
+
+     def get_similarity_summary(self, similarity_data: Dict[str, Any]) -> Dict[str, Any]:
+         """Get a summary of similarity statistics."""
+         attention_matrix = similarity_data['attention_matrix']
+
+         return {
+             'overall_similarity': similarity_data['overall_similarity'],
+             'max_similarity': float(np.max(attention_matrix)),
+             'min_similarity': float(np.min(attention_matrix)),
+             'std_similarity': float(np.std(attention_matrix)),
+             'query_patches_count': similarity_data['query_patches_shape'][0],
+             'candidate_patches_count': similarity_data['candidate_patches_shape'][0],
+             'high_attention_patches': int(np.sum(attention_matrix > (np.mean(attention_matrix) + np.std(attention_matrix)))),
+             'model_name': self.embedding_model.get_model_name()
+         }
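The grid bookkeeping in PatchAttentionAnalyzer assumes a square ViT patch grid: with the 224x224 preprocessing used here, CLIP ViT-B/32 (32-pixel patches) yields a 7x7 grid of 49 patches, while dinov2_vitb14 (14-pixel patches) yields 16x16 = 256. A small standalone sketch of the index-to-coordinate mapping used by `_patch_idx_to_coord` (function names here are illustrative, not part of the module):

```python
import math

def patch_grid(num_patches: int) -> int:
    """Side length of the (assumed square) patch grid, e.g. 49 -> 7."""
    side = int(math.sqrt(num_patches))
    assert side * side == num_patches, "patch count is not a perfect square"
    return side

def idx_to_coord(patch_idx: int, grid_size: int) -> tuple:
    """Flat patch index -> (row, col), row-major, matching _patch_idx_to_coord."""
    return patch_idx // grid_size, patch_idx % grid_size

# 224 / 32 = 7 patches per side for CLIP ViT-B/32, so 49 patches in total
print(patch_grid(49))        # 7
print(idx_to_coord(10, 7))   # (1, 3): second row, fourth column
```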
requirements.txt ADDED
@@ -0,0 +1,30 @@
+ # Core FastAPI dependencies
+ fastapi>=0.100.0
+ uvicorn[standard]>=0.20.0
+ python-multipart
+ python-dotenv
+
+ # ML and Computer Vision
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.30.0
+ huggingface-hub>=0.15.0
+ open_clip_torch>=2.20.0
+ timm>=0.9.0
+
+ # Image processing
+ pillow>=10.0.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+
+ # Web scraping and search
+ requests>=2.30.0
+ duckduckgo_search>=4.0.0
+ lxml>=4.9.0
+
+ # Utilities
+ tqdm>=4.65.0
+ packaging>=23.0
+ regex
+ PyYAML>=6.0