Spaces:
Running
Running
| import os | |
| import json | |
| import traceback | |
| from typing import Dict, Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoProcessor, AutoModel | |
| import gradio as gr | |
# --- Device Setup ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Dtype the embedding is cast to before it enters the classifier head; the
# vision backbone's dtype is chosen separately when the backbone is loaded.
HEAD_DTYPE: torch.dtype = torch.float32

# --- DINOv3 Specific Constants ---
# Image sides are rounded up to a multiple of the patch size, and the longest
# side is capped at the maximum resolution, before DINOv3 preprocessing.
DINOV3_PATCH_SIZE: int = 16
MAX_DINOV3_RESOLUTION: int = 4096

print(f"Using device: {DEVICE}")
print(f"Head model dtype: {HEAD_DTYPE}")
| # --- Model Definitions (Copied from hybrid_model.py) --- | |
| # (RMSNorm, SwiGLUFFN, ResBlockRMS, HybridHeadModel classes are unchanged and go here) | |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: Size of the last dimension (length of the gain vector).
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to the input dtype, then apply the gain."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer computing ``w3(act(gate) * up)``.

    ``w12`` produces the gate and up projections in a single matmul; its
    output is split in half along the last dimension.

    Args:
        in_features: Input width.
        hidden_features: Gated width; derived from ``in_features`` when falsy,
            and always rounded up to an even number.
        out_features: Output width; defaults to ``in_features`` when falsy.
        act_layer: Activation class applied to the gate half.
        dropout: Dropout probability used after the gating and after ``w3``.
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None,
                 act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = int(in_features * 8 / 3 / 2 * 2)
        # Round an odd width up to the next even number.
        hidden_features = (hidden_features + 1) // 2 * 2

        self.w12 = nn.Linear(in_features, hidden_features * 2, bias=False)
        self.act = act_layer()
        self.dropout1 = nn.Dropout(dropout)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated projection followed by the output projection."""
        gate, up = self.w12(x).chunk(2, dim=-1)
        gated = self.dropout1(self.act(gate) * up)
        return self.dropout2(self.w3(gated))
class ResBlockRMS(nn.Module):
    """Residual block computing ``x + SwiGLUFFN(RMSNorm(x))``.

    Args:
        ch: Channel width, kept unchanged by the block.
        dropout: Dropout probability forwarded to the feed-forward layer.
        rms_norm_eps: Epsilon forwarded to the normalisation layer.
    """

    def __init__(self, ch: int, dropout: float = 0.0, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(ch, eps=rms_norm_eps)
        self.ffn = SwiGLUFFN(in_features=ch, dropout=dropout)

    def forward(self, x):
        """Return the input plus the normalised feed-forward branch."""
        branch = self.ffn(self.norm(x))
        return x + branch
class HybridHeadModel(nn.Module):
    """Classifier head: optional self-attention on the embedding, then an MLP.

    The MLP is ``Linear -> RMSNorm -> num_res_blocks x ResBlockRMS -> RMSNorm
    -> SwiGLUFFN (halves the width) -> RMSNorm -> Linear(num_classes)``.

    Args:
        features: Width of the incoming embedding.
        hidden_dim: Width of the MLP trunk.
        num_classes: Number of output logits.
        use_attention: Whether to build and apply the self-attention stage.
        num_attn_heads: Requested head count; adjusted when it does not divide
            ``features``.
        attn_dropout: Dropout inside the attention layer.
        num_res_blocks: Number of residual blocks in the trunk.
        dropout_rate: Dropout used by the residual blocks and down projection.
        rms_norm_eps: Epsilon for every RMSNorm layer.
        output_mode: One of ``linear``, ``sigmoid``, ``softmax``, ``tanh_scaled``
            (case-insensitive); checked when ``forward`` runs.
    """

    def __init__(self, features: int, hidden_dim: int = 1280, num_classes: int = 2, use_attention: bool = True,
                 num_attn_heads: int = 16, attn_dropout: float = 0.1, num_res_blocks: int = 3,
                 dropout_rate: float = 0.1, rms_norm_eps: float = 1e-6, output_mode: str = 'linear'):
        super().__init__()
        self.features = features
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.use_attention = use_attention
        self.output_mode = output_mode.lower()

        self.attention = None
        self.norm_attn = None
        if self.use_attention:
            heads = self._resolve_num_heads(features, num_attn_heads)
            self.attention = nn.MultiheadAttention(features, heads, dropout=attn_dropout, batch_first=True, bias=True)
            self.norm_attn = RMSNorm(features, eps=rms_norm_eps)

        half_width = hidden_dim // 2
        # Layers are created in this exact order so the Sequential indices stay
        # the same as before.
        self.mlp_head = nn.Sequential(
            nn.Linear(features, hidden_dim),
            RMSNorm(hidden_dim, eps=rms_norm_eps),
            *[ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps)
              for _ in range(num_res_blocks)],
            RMSNorm(hidden_dim, eps=rms_norm_eps),
            SwiGLUFFN(hidden_dim, hidden_features=half_width, out_features=half_width, dropout=dropout_rate),
            RMSNorm(half_width, eps=rms_norm_eps),
            nn.Linear(half_width, num_classes),
        )

    @staticmethod
    def _resolve_num_heads(features, requested):
        """Return ``requested`` if it divides ``features``, else the closest candidate that does."""
        if features % requested == 0:
            return requested
        divisors = [h for h in [1, 2, 4, 8, 16, 32] if features % h == 0]
        chosen = min(divisors, key=lambda h: abs(h - requested), default=1)
        if chosen != requested:
            print(f"HybridHead Warning: Adjusting heads {requested}->{chosen} for features={features}")
        return chosen

    def forward(self, x: torch.Tensor):
        """Return the head output for ``x`` after applying ``output_mode``.

        Raises:
            RuntimeError: If ``output_mode`` is not a supported value.
        """
        if self.use_attention and self.attention is not None:
            # The embedding is treated as a length-1 sequence for attention.
            seq = x.unsqueeze(1)
            attended, _ = self.attention(seq, seq, seq)
            x = self.norm_attn(x + attended.squeeze(1))

        logits = self.mlp_head(x.to(HEAD_DTYPE))

        mode = self.output_mode
        activations = {
            'linear': lambda t: t,
            'sigmoid': torch.sigmoid,
            'softmax': lambda t: F.softmax(t, dim=-1),
            'tanh_scaled': lambda t: (torch.tanh(t) + 1.0) / 2.0,
        }
        if mode not in activations:
            raise RuntimeError(f"Invalid output_mode '{mode}'.")
        output = activations[mode](logits)

        # A single-logit head returns shape (batch,) instead of (batch, 1).
        if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1:
            output = output.squeeze(-1)
        return output
# --- Model Catalog ---
# Every classifier head (config + weights) lives in the same Hub repository.
_HEAD_REPO_ID = "Enferlain/lumi-classifier"

# Display name -> files needed to assemble that model.
MODEL_CATALOG = {
    "AnatomyFlaws-v15.5 (DINOv3 7b bf16)": {
        "repo_id": _HEAD_REPO_ID,
        "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
        "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
        # Explicitly define the vision model repo ID to prevent errors.
        # Alternative backbones kept for reference:
        #   "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit"  (bnb 8bit)
        #   "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4"  (int4)
        "vision_model_repo_id": "PIA-SPACE-LAB/dinov3-vit7b16-pretrain-lvd1689m",
    },
    "AnatomyFlaws-v14.7 (SigLIP naflex)": {
        "repo_id": _HEAD_REPO_ID,
        "config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
        "head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors",
        # The base SigLIP model is not custom, so we use its official ID.
        "vision_model_repo_id": "google/siglip2-so400m-patch16-naflex",
    },
}
| # --- Model Manager Class --- | |
class ModelManager:
    """Load and hold one vision backbone plus its classifier head at a time.

    ``load_model`` switches between catalog entries. The previously loaded
    model is released before the next one is loaded, so two backbones are
    never held in memory together. After a failed load every attribute is
    reset to ``None``.

    Args:
        catalog: Maps a display name to a dict with the keys ``repo_id``,
            ``config_filename``, ``head_filename`` and ``vision_model_repo_id``.
    """

    # Head constructor arguments that fall back to HybridHeadModel's own
    # defaults when the config omits them or stores null.
    _OPTIONAL_HEAD_KEYS = (
        "hidden_dim",
        "num_attn_heads",
        "attn_dropout",
        "num_res_blocks",
        "dropout_rate",
        "output_mode",
    )

    def __init__(self, catalog: Dict[str, Dict[str, str]]):
        self.catalog = catalog
        # Everything below stays None until load_model() succeeds.
        self.current_model_name: str = None
        self.vision_model: nn.Module = None
        self.hf_processor: Any = None
        self.head_model: HybridHeadModel = None
        self.labels: Dict[int, str] = None
        self.config: Dict[str, Any] = None

    def _release_current(self):
        """Drop every reference to the loaded model and reclaim its memory."""
        import gc  # local import: only needed when switching models

        self.current_model_name = None
        self.vision_model = None
        self.hf_processor = None
        self.head_model = None
        self.labels = None
        self.config = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _head_kwargs(cls, config):
        """Build the HybridHeadModel keyword arguments from a loaded config.

        Head parameters are read from ``config["predictor_params"]`` when that
        key exists, otherwise from ``config`` itself. Missing or null optional
        values are left out so the head's own defaults apply instead of
        ``None``. A missing ``use_attention`` still means "no attention".

        Raises:
            ValueError: If the config does not provide ``features``.
        """
        head_params = config.get("predictor_params", config)
        features = head_params.get("features")
        if features is None:
            raise ValueError("Head config is missing the required 'features' entry.")

        kwargs = {
            "features": features,
            "use_attention": bool(head_params.get("use_attention")),
        }
        for key in cls._OPTIONAL_HEAD_KEYS:
            value = head_params.get(key)
            if value is not None:
                kwargs[key] = value

        # num_classes is read from the top level first, as before.
        num_classes = config.get("num_classes")
        if num_classes is None:
            num_classes = head_params.get("num_classes")
        if num_classes is not None:
            kwargs["num_classes"] = num_classes
        return kwargs

    @staticmethod
    def _load_vision_model(vision_model_repo_id):
        """Load the vision backbone in eval mode on DEVICE."""
        if DEVICE != "cpu":
            # trust_remote_code matches the processor and the CPU path.
            return AutoModel.from_pretrained(
                vision_model_repo_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(DEVICE).eval()

        # For CPU, load unquantized model with BF16 (original format)
        print("Loading unquantized model for CPU...")
        try:
            model = AutoModel.from_pretrained(
                vision_model_repo_id,
                torch_dtype=torch.bfloat16,
                device_map={"": "cpu"},
                trust_remote_code=True,
            ).eval()
            print("Successfully loaded model in BF16 format.")
            return model
        except Exception as bf16_error:
            print(f"BF16 loading failed: {bf16_error}")
            print("Falling back to FP32...")
            model = AutoModel.from_pretrained(
                vision_model_repo_id,
                torch_dtype=torch.float32,
                device_map={"": "cpu"},
                trust_remote_code=True,
            ).eval()
            print("Successfully loaded model in FP32 format.")
            return model

    def load_model(self, model_name: str):
        """Make ``model_name`` the active model, downloading files as needed.

        Does nothing when ``model_name`` is already loaded.

        Raises:
            ValueError: If ``model_name`` is not in the catalog.
            RuntimeError: If any download or load step fails. The original
                exception is chained as the cause.
        """
        if model_name == self.current_model_name:
            return
        if model_name not in self.catalog:
            raise ValueError(f"Model '{model_name}' not found.")

        print(f"Switching to model: {model_name}...")
        model_info = self.catalog[model_name]
        repo_id = model_info["repo_id"]
        config_filename = model_info["config_filename"]
        head_filename = model_info["head_filename"]
        vision_model_repo_id = model_info["vision_model_repo_id"]

        # Free the old backbone first so two never sit in memory together.
        self._release_current()
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

            print(f"Loading vision model: {vision_model_repo_id}")
            self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
            self.vision_model = self._load_vision_model(vision_model_repo_id)

            head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
            print(f"Loading head model: {head_filename}")
            state_dict = load_file(head_model_path, device='cpu')
            self.head_model = HybridHeadModel(**self._head_kwargs(self.config))
            self.head_model.load_state_dict(state_dict, strict=True)
            self.head_model.to(DEVICE).eval()

            raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
            # A label value may be a plain string or a dict with a 'name' entry.
            self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}

            self.current_model_name = model_name
            print(f"Successfully loaded '{model_name}'.")
        except Exception as e:
            # Do not leave a half-loaded model behind.
            self._release_current()
            raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}") from e
# --- Global Model Manager Instance ---
# One shared manager built from the catalog; nothing is loaded until
# load_model() is called on it.
model_manager = ModelManager(catalog=MODEL_CATALOG)
| # --- Prediction Function (v3 from before) --- | |
def _fit_to_dinov3_grid(pil_image):
    """Resize ``pil_image`` so both sides are patch multiples within the resolution cap."""
    width, height = pil_image.size
    fitted = pil_image
    longest = max(width, height)
    if longest > MAX_DINOV3_RESOLUTION:
        scale = MAX_DINOV3_RESOLUTION / longest
        # max(1, ...) stops extreme aspect ratios collapsing a side to 0 px.
        width = max(1, int(width * scale))
        height = max(1, int(height * scale))
        fitted = pil_image.resize((width, height), Image.Resampling.LANCZOS)

    # Round each side up to the next multiple of the patch size.
    grid_w = ((width + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
    grid_h = ((height + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
    if (grid_w, grid_h) != (width, height):
        fitted = fitted.resize((grid_w, grid_h), Image.Resampling.LANCZOS)
    return fitted


def _embed_dinov3(pil_image):
    """Return the mean of the DINOv3 patch-token embeddings for one image."""
    vision_model = model_manager.vision_model
    inputs = model_manager.hf_processor(images=[_fit_to_dinov3_grid(pil_image)], return_tensors="pt")
    pixel_values = inputs.pixel_values.to(vision_model.device)
    last_hidden_state = vision_model(pixel_values=pixel_values).last_hidden_state
    # Skip the first token (presumably CLS) and the register tokens.
    nreg = getattr(vision_model.config, 'num_register_tokens', 0)
    patch_embeddings = last_hidden_state[:, 1 + nreg:]
    return torch.mean(patch_embeddings, dim=1)


def _embed_siglip(pil_image, base_model_type):
    """Return the SigLIP image embedding, using the naflex inputs when applicable."""
    vision_model = model_manager.vision_model
    inputs = model_manager.hf_processor(images=[pil_image], return_tensors="pt")
    # Cast to the dtype the backbone was loaded in (fp16 on GPU, bf16 or fp32
    # on CPU) rather than assuming fp16.
    pixel_values = inputs.get("pixel_values").to(device=vision_model.device, dtype=vision_model.dtype)

    if "naflex" not in base_model_type.lower():
        return vision_model.get_image_features(pixel_values=pixel_values)

    attention_mask = inputs.get("pixel_attention_mask").to(device=vision_model.device)
    # as_tensor avoids re-wrapping a value that may already be a tensor.
    spatial_shapes = torch.as_tensor(inputs.get("spatial_shapes"), dtype=torch.long).to(vision_model.device)
    vision_tower = getattr(vision_model, 'vision_model', vision_model)
    outputs = vision_tower(
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        spatial_shapes=spatial_shapes,
    )
    return outputs.pooler_output


def _prediction_to_probs(prediction, labels):
    """Turn the head output for one image into a ``{label: probability}`` dict.

    A single score goes through a sigmoid and is reported as the probability
    of label 1. Several scores go through a softmax and are reported per class.

    NOTE(review): the activation is applied whatever the head's output_mode
    is; if a config uses 'sigmoid' or 'softmax' it would be applied twice.
    Confirm the shipped configs use 'linear'.
    """
    scores = prediction.squeeze().float()
    if scores.ndim == 0:
        prob_good = torch.sigmoid(scores).item()
        return {labels[0]: 1.0 - prob_good, labels[1]: prob_good}
    probs = F.softmax(scores, dim=-1)
    return {labels.get(i, f"Class {i}"): probs[i].item() for i in range(probs.shape[-1])}


def predict_anatomy_v3(image: "Image.Image", model_name: str):
    """Classify ``image`` with the catalog model ``model_name``.

    Args:
        image: Input picture, or ``None`` when nothing was uploaded.
        model_name: Key of the model in the catalog; loaded if not active.

    Returns:
        A ``{label: probability}`` dict. If no image was given or an exception
        occurs, a two-entry dict describing the error is returned instead of
        raising.
    """
    if image is None:
        return {"Error": 1.0, "Info": 0.0}  # Return numeric values

    try:
        model_manager.load_model(model_name)
        pil_image = image.convert("RGB")
        base_model_type = model_manager.config.get("base_vision_model", "")
        family = base_model_type.lower()

        with torch.no_grad():
            if "dinov3" in family:
                emb = _embed_dinov3(pil_image)
            elif "siglip" in family:
                emb = _embed_siglip(pil_image, base_model_type)
            else:
                raise ValueError(f"Unknown base model type for embedding: {base_model_type}")
            if emb is None:
                raise ValueError("Failed to get embedding.")

            # L2-normalise the embedding; the norm is computed in float32.
            norm = torch.linalg.norm(emb.float(), dim=-1, keepdim=True).clamp(min=1e-8)
            emb_normalized = emb / norm.to(emb.dtype)
            prediction = model_manager.head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))

        return _prediction_to_probs(prediction, model_manager.labels)
    except Exception as e:
        print(f"Error during prediction: {e}\n{traceback.format_exc()}")
        # Return properly formatted error for Gradio Label
        message = str(e)
        error_msg = message[:50] + "..." if len(message) > 50 else message
        return {
            f"Error: {error_msg}": 1.0,
            "Please check logs": 0.0,
        }
# --- Gradio Interface ---
DESCRIPTION = """
## Lumi's Anatomy Flaw Classifier Demo ✨
Select a model from the dropdown, then upload an image to classify its anatomy/structural correctness.
Will be slow since it runs on cpu, ~2minutes on dinov3.
"""
EXAMPLE_DIR = "examples"

# The first catalog entry is the default model.
_model_choices = list(MODEL_CATALOG)
default_model = _model_choices[0]

# 1. Collect the example images (sorted by file name), if the folder exists.
example_paths = []
if os.path.isdir(EXAMPLE_DIR):
    for fname in sorted(os.listdir(EXAMPLE_DIR)):
        if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
            example_paths.append(os.path.join(EXAMPLE_DIR, fname))

# 2. Gradio wants one row per example: [image, model_name].
examples_nested = [[path, default_model] for path in example_paths]

# 3. Build the interface; an empty example list is passed as None.
_image_input = gr.Image(type="pil", label="Input Image")
_model_input = gr.Dropdown(choices=_model_choices, value=default_model, label="Classifier Model")
interface = gr.Interface(
    fn=predict_anatomy_v3,
    inputs=[_image_input, _model_input],
    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
    title="Lumi's Anatomy Classifier",
    description=DESCRIPTION,
    examples=examples_nested or None,
    allow_flagging="never",
    cache_examples=True,
)
if __name__ == "__main__":
    # Best-effort warm-up: a failed pre-load is only reported, and the app
    # still starts.
    try:
        print("Pre-loading default model...")
        model_manager.load_model(default_model)
    except Exception as preload_error:
        print(f"WARNING: Could not pre-load default model. Error: {preload_error}")

    interface.launch()