import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import AutoImageProcessor, AutoModel
import torch.nn.functional as F
import spaces
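
# DINOv3 checkpoints on the Hugging Face Hub, keyed by the display name shown in the UI.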
DINO_MODELS = {
    "DINOv3 Base ViT": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "DINOv3 Large ViT": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
}

_default_model_name = "DINOv3 Base ViT"
processor = AutoImageProcessor.from_pretrained(DINO_MODELS[_default_model_name])
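
# Swap the global image processor when a different checkpoint is selected and report status;
# used as the dropdown's change handler (the model itself is loaded per call in extract_features).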
def load_model(model_name):
    global processor
    model_path = DINO_MODELS[model_name]
    processor = AutoImageProcessor.from_pretrained(model_path)
    return f"✅ Model '{model_name}' loaded successfully!"
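
# Run the selected DINOv3 backbone on a single image and return its patch-level features
# (CLS and register tokens stripped), plus the original and model input resolutions.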
def extract_features(image, model_name):
    model_id = DINO_MODELS[model_name]
    model = AutoModel.from_pretrained(model_id).to("cuda").eval()
    local_processor = AutoImageProcessor.from_pretrained(model_id)
    inputs = local_processor(images=image, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    model_size = local_processor.size["height"]
    original_size = image.size
    with torch.no_grad():
        outputs = model(**inputs)
        features = outputs.last_hidden_state
    num_register_tokens = getattr(model.config, "num_register_tokens", 0)
    return features[:, 1 + num_register_tokens:, :].float().cpu(), original_size, model_size
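
# Mutual nearest-neighbour matching: a pair of patches is kept only if each patch is the
# other's most similar patch (cosine similarity) and the similarity exceeds the threshold.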
def find_correspondences(features1, features2, threshold=0.8):
    device = torch.device("cpu")
    B, N1, D = features1.shape
    _, N2, _ = features2.shape
    features1_norm = F.normalize(features1, dim=-1)
    features2_norm = F.normalize(features2, dim=-1)
    similarity = torch.matmul(features1_norm, features2_norm.transpose(-2, -1))
    matches1 = torch.argmax(similarity, dim=-1)
    matches2 = torch.argmax(similarity, dim=-2)
    max_sim1 = torch.max(similarity, dim=-1)[0]
    arange1 = torch.arange(N1, device=device)
    mutual_matches = matches2[0, matches1[0]] == arange1
    good_matches = (max_sim1[0] > threshold) & mutual_matches
    return matches1[0][good_matches].cpu(), arange1[good_matches].cpu(), max_sim1[0][good_matches].cpu()
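
# Map a flat patch index back to pixel coordinates: locate the patch centre on the model's
# square input grid, then rescale to the original image resolution.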
def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=16):
    # DINOv3 ViT/16 checkpoints (vitb16, vitl16) use a 16-pixel patch size.
    orig_w, orig_h = original_size
    patches_h = model_size // patch_size
    patches_w = model_size // patch_size
    if patch_idx >= patches_h * patches_w:
        return None, None
    patch_y = patch_idx // patches_w
    patch_x = patch_idx % patches_w
    y_model = patch_y * patch_size + patch_size // 2
    x_model = patch_x * patch_size + patch_size // 2
    x = int(x_model * orig_w / model_size)
    y = int(y_model * orig_h / model_size)
    return x, y
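
# End-to-end pipeline for the UI: extract features from both images, find mutual matches,
# and draw the corresponding keypoints side by side on a single canvas.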
@spaces.GPU  # allocate a ZeroGPU device for the CUDA calls in extract_features (the Space runs on ZeroGPU)
def match_keypoints(image1, image2, model_name):
    if image1 is None or image2 is None:
        return None
    load_model(model_name)
    img1_pil = Image.fromarray(image1).convert("RGB")
    img2_pil = Image.fromarray(image2).convert("RGB")
    features1, original_size1, model_size1 = extract_features(img1_pil, model_name)
    features2, original_size2, model_size2 = extract_features(img2_pil, model_name)
    matches2_idx, matches1_idx, similarities = find_correspondences(features1, features2, threshold=0.7)
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)
    h1, w1 = img1_np.shape[:2]
    h2, w2 = img2_np.shape[:2]
    result_img = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    result_img[:h1, :w1] = img1_np
    result_img[:h2, w1:w1 + w2] = img2_np
    for m1, m2, _ in zip(matches1_idx, matches2_idx, similarities):
        x1, y1 = patch_to_image_coords(int(m1), original_size1, model_size1)
        x2, y2 = patch_to_image_coords(int(m2), original_size2, model_size2)
        if x1 is not None and x2 is not None:
            color = tuple(np.random.randint(0, 255, size=3).tolist())
            cv2.circle(result_img, (x1, y1), 6, color, -1)
            cv2.circle(result_img, (x2 + w1, y2), 6, color, -1)
            cv2.line(result_img, (x1, y1), (x2 + w1, y2), color, 2)
    return result_img
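
# Gradio UI: two image inputs, a model selector, a status line, and the matched-keypoints output.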
with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
    gr.Markdown("# DINOv3 For Keypoint Matching")
    gr.Markdown("DINOv3 can be used to find matching features between two images.")
    gr.Markdown("Upload two images to find corresponding keypoints using DINOv3 features, and switch between different DINOv3 checkpoints.")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row(scale=1):
                image1 = gr.Image(label="Image 1", type="numpy")
                image2 = gr.Image(label="Image 2", type="numpy")
            model_selector = gr.Dropdown(
                choices=list(DINO_MODELS.keys()),
                value=_default_model_name,
                label="Select DINOv3 Model",
                info="Choose the model size. Larger models may provide better features but require more memory.",
            )
            status_bar = gr.Textbox(
                value=f"✅ Model '{_default_model_name}' ready.",
                label="Status",
                interactive=False,
                container=False,
            )
            match_btn = gr.Button("Find Correspondences", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Matched Keypoints")

    model_selector.change(fn=load_model, inputs=[model_selector], outputs=[status_bar])
    match_btn.click(fn=match_keypoints, inputs=[image1, image2, model_selector], outputs=[output_image])
    gr.Examples(
        examples=[["map.jpg", "street.jpg"], ["bee.JPG", "bee_edited.jpg"]],
        inputs=[image1, image2],
    )

if __name__ == "__main__":
    demo.launch()