# Enferlain's picture
# Update app.py
# 53532e2 verified
import os
import json
import traceback
from typing import Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModel
import gradio as gr
# --- Device Setup ---
# Use the GPU when one is visible, otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dtype for the classifier head. The vision backbone chooses its own dtype at load
# time; embeddings are cast to this dtype before they reach the head.
HEAD_DTYPE = torch.float32

# --- DINOv3 Specific Constants ---
DINOV3_PATCH_SIZE = 16          # both image sides are resized to a multiple of this
MAX_DINOV3_RESOLUTION = 4096    # the longer image side is capped here first

print(f"Using device: {DEVICE}")
print(f"Head model dtype: {HEAD_DTYPE}")
# --- Model Definitions (Copied from hybrid_model.py) ---
# RMSNorm, SwiGLUFFN, ResBlockRMS and HybridHeadModel are copied unchanged from that file; their parameter names must match the checkpoints, which are loaded with strict=True.
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Args:
        dim: Size of the last axis; ``weight`` (the gain) has this many entries.
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        # The product follows normal type promotion with ``weight`` (float32 by default).
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
class SwiGLUFFN(nn.Module):
    """Bias-free SwiGLU feed-forward layer.

    ``w12`` projects to two halves (gate, up); the output is
    ``dropout2(w3(dropout1(act(gate) * up)))``.

    Args:
        in_features: Input width.
        hidden_features: Gated hidden width; a falsy value means int(8/3 * in_features).
            The result is always rounded up to an even number.
        out_features: Output width; a falsy value means ``in_features``.
        act_layer: Activation class, instantiated once (SiLU by default).
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = int(in_features * 8 / 3)
        hidden_features = 2 * ((hidden_features + 1) // 2)

        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=False)
        self.act = act_layer()
        self.dropout1 = nn.Dropout(dropout)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        projected = self.w12(x)
        gate, up = projected.chunk(2, dim=-1)
        gated = self.dropout1(self.act(gate) * up)
        return self.dropout2(self.w3(gated))
class ResBlockRMS(nn.Module):
    """Pre-norm residual block of constant width: ``x + ffn(norm(x))``.

    Args:
        ch: Channel width, unchanged by the block.
        dropout: Dropout probability forwarded to the SwiGLU feed-forward.
        rms_norm_eps: Epsilon of the pre-normalisation.
    """

    def __init__(self, ch: int, dropout: float = 0.0, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(ch, eps=rms_norm_eps)
        self.ffn = SwiGLUFFN(in_features=ch, dropout=dropout)

    def forward(self, x):
        normed = self.norm(x)
        update = self.ffn(normed)
        return x + update
class HybridHeadModel(nn.Module):
    """Classifier head mapping one pooled image embedding to ``num_classes`` outputs.

    Pipeline:
      1. Optional self-attention over the embedding treated as a length-1 sequence,
         with a residual connection and RMSNorm.
      2. Linear(features, hidden_dim) -> RMSNorm.
      3. ``num_res_blocks`` x ResBlockRMS -> RMSNorm.
      4. SwiGLU down-projection to hidden_dim // 2 -> RMSNorm.
      5. Linear(num_classes).
      6. The activation named by ``output_mode`` ('linear', 'sigmoid', 'softmax'
         or 'tanh_scaled').

    Checkpoint keys are ``attention.*``, ``norm_attn.*`` and ``mlp_head.<i>.*``,
    so the order of the layers inside ``mlp_head`` must not change.
    """

    # Head counts tried when ``features`` is not divisible by the requested count.
    _HEAD_COUNT_CANDIDATES = (1, 2, 4, 8, 16, 32)

    def __init__(self, features: int, hidden_dim: int = 1280, num_classes: int = 2, use_attention: bool = True,
                 num_attn_heads: int = 16, attn_dropout: float = 0.1, num_res_blocks: int = 3,
                 dropout_rate: float = 0.1, rms_norm_eps: float = 1e-6, output_mode: str = 'linear'):
        super().__init__()
        self.features = features
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.use_attention = use_attention
        self.output_mode = output_mode.lower()
        self.attention = None
        self.norm_attn = None

        if self.use_attention:
            n_heads = self._compatible_head_count(features, num_attn_heads)
            if n_heads != num_attn_heads:
                print(f"HybridHead Warning: Adjusting heads {num_attn_heads}->{n_heads} for features={features}")
            self.attention = nn.MultiheadAttention(features, n_heads, dropout=attn_dropout, batch_first=True, bias=True)
            self.norm_attn = RMSNorm(features, eps=rms_norm_eps)

        reduced_dim = hidden_dim // 2
        layers = [nn.Linear(features, hidden_dim), RMSNorm(hidden_dim, eps=rms_norm_eps)]
        layers.extend(
            ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps)
            for _ in range(num_res_blocks)
        )
        layers.extend([
            RMSNorm(hidden_dim, eps=rms_norm_eps),
            SwiGLUFFN(hidden_dim, hidden_features=reduced_dim, out_features=reduced_dim, dropout=dropout_rate),
            RMSNorm(reduced_dim, eps=rms_norm_eps),
            nn.Linear(reduced_dim, num_classes),
        ])
        self.mlp_head = nn.Sequential(*layers)

    @classmethod
    def _compatible_head_count(cls, features: int, requested: int) -> int:
        """Return ``requested`` if it divides ``features``, else the closest candidate that does."""
        if features % requested == 0:
            return requested
        divisors = [h for h in cls._HEAD_COUNT_CANDIDATES if features % h == 0]
        if not divisors:
            return 1
        return min(divisors, key=lambda h: abs(h - requested))

    def forward(self, x: torch.Tensor):
        if self.use_attention and self.attention is not None:
            tokens = x.unsqueeze(1)
            attended, _ = self.attention(tokens, tokens, tokens)
            x = self.norm_attn(x + attended.squeeze(1))

        logits = self.mlp_head(x.to(HEAD_DTYPE))

        mode = self.output_mode
        activations = {
            'linear': lambda t: t,
            'sigmoid': torch.sigmoid,
            'softmax': lambda t: F.softmax(t, dim=-1),
            'tanh_scaled': lambda t: (torch.tanh(t) + 1.0) / 2.0,
        }
        activation = activations.get(mode)
        if activation is None:
            raise RuntimeError(f"Invalid output_mode '{mode}'.")
        output = activation(logits)

        # A single-output head returns a flat (batch,) tensor instead of (batch, 1).
        if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1:
            output = output.squeeze(-1)
        return output
# --- Model Catalog ---
# Dropdown label -> where to fetch the classifier config/head and which backbone to load.
# The backbone repo id is spelled out per entry so it is never inferred from the config.
_LUMI_CLASSIFIER_REPO = "Enferlain/lumi-classifier"

MODEL_CATALOG = {
    "AnatomyFlaws-v15.5 (DINOv3 7b bf16)": dict(
        repo_id=_LUMI_CLASSIFIER_REPO,
        config_filename="AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
        head_filename="AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
        # Unquantized backbone. Quantized alternatives of the same model:
        # Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit (bitsandbytes 8-bit) and
        # Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4 (int4).
        vision_model_repo_id="PIA-SPACE-LAB/dinov3-vit7b16-pretrain-lvd1689m",
    ),
    "AnatomyFlaws-v14.7 (SigLIP naflex)": dict(
        repo_id=_LUMI_CLASSIFIER_REPO,
        config_filename="AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
        head_filename="AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors",
        # Stock SigLIP2 NaFlex backbone from its official repo.
        vision_model_repo_id="google/siglip2-so400m-patch16-naflex",
    ),
}
# --- Model Manager Class ---
class ModelManager:
    """Download, build and cache one catalog entry (processor, backbone, head) at a time.

    ``load_model`` does nothing when the requested entry is already active. A new
    entry is assembled entirely in local variables and committed to ``self`` only
    after every step has succeeded. A failed switch therefore leaves the previously
    active entry (if any) intact, rather than a mix of old and new components.
    """

    # Predictor hyper-parameters forwarded to HybridHeadModel when the config
    # provides them. Absent or null keys fall back to HybridHeadModel's own
    # defaults instead of overriding them with None.
    _HEAD_PARAM_KEYS = (
        "features",
        "hidden_dim",
        "use_attention",
        "num_attn_heads",
        "attn_dropout",
        "num_res_blocks",
        "dropout_rate",
        "rms_norm_eps",
        "output_mode",
    )

    def __init__(self, catalog: Dict[str, Dict[str, str]]):
        """Store ``catalog`` (display name -> repo/file ids); nothing is loaded yet."""
        self.catalog = catalog
        self.current_model_name: str = None
        self.vision_model: nn.Module = None
        self.hf_processor: Any = None
        self.head_model: HybridHeadModel = None
        self.labels: Dict[int, str] = None
        self.config: Dict[str, Any] = None

    def load_model(self, model_name: str):
        """Make ``model_name`` the active catalog entry, loading it if necessary.

        Raises:
            ValueError: ``model_name`` is not in the catalog.
            RuntimeError: A download or build step failed. The original exception
                is chained as ``__cause__`` and its traceback is included in the
                message.
        """
        if model_name == self.current_model_name:
            return
        if model_name not in self.catalog:
            raise ValueError(f"Model '{model_name}' not found.")
        print(f"Switching to model: {model_name}...")
        model_info = self.catalog[model_name]
        repo_id = model_info["repo_id"]
        config_filename = model_info["config_filename"]
        head_filename = model_info["head_filename"]
        vision_model_repo_id = model_info["vision_model_repo_id"]
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            print(f"Loading vision model: {vision_model_repo_id}")
            hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
            vision_model = self._load_vision_model(vision_model_repo_id)

            head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
            print(f"Loading head model: {head_filename}")
            state_dict = load_file(head_model_path, device='cpu')
            head_model = HybridHeadModel(**self._head_kwargs(config))
            head_model.load_state_dict(state_dict, strict=True)
            head_model.to(DEVICE).eval()

            raw_labels = config.get("labels", {'0': 'Bad', '1': 'Good'})
            # Label entries may be plain strings or {"name": ...} dicts.
            labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}
        except Exception as e:
            raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}") from e

        # Commit only after everything above succeeded.
        self.config = config
        self.hf_processor = hf_processor
        self.vision_model = vision_model
        self.head_model = head_model
        self.labels = labels
        self.current_model_name = model_name
        print(f"Successfully loaded '{model_name}'.")

    @staticmethod
    def _load_vision_model(vision_model_repo_id: str):
        """Load the backbone in eval mode.

        On CPU it loads in BF16, falling back to FP32 if that fails. On CUDA it
        loads in FP16.
        """
        if DEVICE == "cpu":
            print("Loading unquantized model for CPU...")
            try:
                vision_model = AutoModel.from_pretrained(
                    vision_model_repo_id,
                    torch_dtype=torch.bfloat16,  # keep the original BF16 format (half the RAM of FP32)
                    device_map={"": "cpu"},
                    trust_remote_code=True,
                ).eval()
            except Exception as bf16_error:
                print(f"BF16 loading failed: {bf16_error}")
                print("Falling back to FP32...")
            else:
                print("Successfully loaded model in BF16 format.")
                return vision_model
            vision_model = AutoModel.from_pretrained(
                vision_model_repo_id,
                torch_dtype=torch.float32,
                device_map={"": "cpu"},
                trust_remote_code=True,
            ).eval()
            print("Successfully loaded model in FP32 format.")
            return vision_model

        # trust_remote_code matches the processor and CPU paths, so custom backbones load on GPU too.
        return AutoModel.from_pretrained(
            vision_model_repo_id,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
        ).to(DEVICE).eval()

    @classmethod
    def _head_kwargs(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build HybridHeadModel kwargs from a classifier config, omitting absent values."""
        head_params = config.get("predictor_params") or config
        kwargs = {key: head_params[key] for key in cls._HEAD_PARAM_KEYS if head_params.get(key) is not None}
        # num_classes normally lives at the config top level; also accept it under predictor_params.
        num_classes = config.get("num_classes")
        if num_classes is None:
            num_classes = head_params.get("num_classes")
        if num_classes is not None:
            kwargs["num_classes"] = num_classes
        return kwargs
# --- Global Model Manager Instance ---
# Shared by predict_anatomy_v3 (every request) and the __main__ preload.
model_manager = ModelManager(catalog=MODEL_CATALOG)
# --- Prediction Function (v3 from before) ---
def _prepare_dinov3_image(pil_image: Image.Image) -> Image.Image:
    """Cap the long side at MAX_DINOV3_RESOLUTION, then resize each side up to a multiple of DINOV3_PATCH_SIZE."""
    width, height = pil_image.size
    image = pil_image
    longest_side = max(width, height)
    if longest_side > MAX_DINOV3_RESOLUTION:
        scale = MAX_DINOV3_RESOLUTION / longest_side
        # max(1, ...) stops extreme aspect ratios from collapsing a side to 0 px,
        # which resize() cannot produce.
        width, height = max(1, int(width * scale)), max(1, int(height * scale))
        image = pil_image.resize((width, height), Image.Resampling.LANCZOS)
    patch = DINOV3_PATCH_SIZE
    target_w = ((width + patch - 1) // patch) * patch
    target_h = ((height + patch - 1) // patch) * patch
    if (target_w, target_h) != (width, height):
        image = image.resize((target_w, target_h), Image.Resampling.LANCZOS)
    return image


def _embed_dinov3(pil_image: Image.Image) -> torch.Tensor:
    """Return the mean-pooled DINOv3 patch-token embedding for a batch of one image."""
    vision_model = model_manager.vision_model
    # NOTE(review): if the processor's config enables do_resize, it will rescale the
    # image again and override the patch-aligned size chosen above. Confirm against
    # the training pipeline.
    inputs = model_manager.hf_processor(images=[_prepare_dinov3_image(pil_image)], return_tensors="pt")
    # Match both the device and the dtype of the backbone weights.
    pixel_values = inputs.pixel_values.to(device=vision_model.device, dtype=vision_model.dtype)
    last_hidden_state = vision_model(pixel_values=pixel_values).last_hidden_state
    num_registers = getattr(vision_model.config, 'num_register_tokens', 0)
    # Skip the leading token and the register tokens, then average the remaining patch tokens.
    return last_hidden_state[:, 1 + num_registers:].mean(dim=1)


def _embed_siglip(pil_image: Image.Image, naflex: bool) -> torch.Tensor:
    """Return the pooled SigLIP image embedding.

    NaFlex backbones also receive the patch attention mask and spatial shapes.
    """
    vision_model = model_manager.vision_model
    device = vision_model.device
    inputs = model_manager.hf_processor(images=[pil_image], return_tensors="pt")
    # Cast to the backbone's own dtype (BF16/FP32 on CPU, FP16 on CUDA) rather than a fixed float16.
    pixel_values = inputs.get("pixel_values").to(device=device, dtype=vision_model.dtype)
    if not naflex:
        return vision_model.get_image_features(pixel_values=pixel_values)
    encoder = getattr(vision_model, 'vision_model', vision_model)
    return encoder(
        pixel_values=pixel_values,
        attention_mask=inputs.get("pixel_attention_mask").to(device=device),
        # as_tensor reuses an existing tensor instead of copy-constructing it with a warning.
        spatial_shapes=torch.as_tensor(inputs.get("spatial_shapes"), dtype=torch.long, device=device),
    ).pooler_output


def _to_label_probs(prediction: torch.Tensor, head_model, labels) -> Dict[str, float]:
    """Turn the head's output for one image into ``{label: probability}``.

    Raw logits (output_mode 'linear') are activated here: softmax for multi-class
    heads, sigmoid for a single-logit head. Any other output_mode means the head
    has already applied its activation, so the values are used as-is instead of
    being squashed a second time.
    """
    scores = prediction.float().reshape(-1)
    is_logits = head_model.output_mode == 'linear'
    if head_model.num_classes == 1:
        prob_good = (torch.sigmoid(scores[0]) if is_logits else scores[0]).item()
        return {labels[0]: 1.0 - prob_good, labels[1]: prob_good}
    probs = F.softmax(scores, dim=-1) if is_logits else scores
    return {labels.get(i, str(i)): p for i, p in enumerate(probs.tolist())}


def predict_anatomy_v3(image: Image.Image, model_name: str):
    """Classify ``image`` with catalog entry ``model_name``; output is shaped for ``gr.Label``.

    Returns:
        A dict mapping class label -> probability. If anything fails, the error is
        printed with its traceback and a two-entry dict is returned instead; its top
        key carries the (truncated) error message so the Label widget shows it.
    """
    if image is None:
        return {"Error": 1.0, "Info": 0.0}  # Return numeric values
    try:
        model_manager.load_model(model_name)
        pil_image = image.convert("RGB")
        base_model_type = model_manager.config.get("base_vision_model", "") or ""
        model_type = base_model_type.lower()
        with torch.no_grad():
            if "dinov3" in model_type:
                emb = _embed_dinov3(pil_image)
            elif "siglip" in model_type:
                emb = _embed_siglip(pil_image, naflex="naflex" in model_type)
            else:
                raise ValueError(f"Unknown base model type for embedding: {base_model_type}")
            if emb is None:
                raise ValueError("Failed to get embedding.")
            # L2-normalise in float32. Dividing in the backbone's BF16/FP16 would
            # discard precision that the float32 head can use.
            emb = emb.float()
            emb_normalized = emb / torch.linalg.norm(emb, dim=-1, keepdim=True).clamp(min=1e-8)
            prediction = model_manager.head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))
        return _to_label_probs(prediction, model_manager.head_model, model_manager.labels)
    except Exception as e:
        print(f"Error during prediction: {e}\n{traceback.format_exc()}")
        # Return properly formatted error for Gradio Label
        error_msg = str(e)[:50] + "..." if len(str(e)) > 50 else str(e)
        return {
            f"Error: {error_msg}": 1.0,
            "Please check logs": 0.0
        }
# --- Gradio Interface ---
DESCRIPTION = """
## Lumi's Anatomy Flaw Classifier Demo ✨
Select a model from the dropdown, then upload an image to classify its anatomy/structural correctness.
Will be slow since it runs on cpu, ~2minutes on dinov3.
"""
EXAMPLE_DIR = "examples"
# The first catalog entry is both the dropdown default and the model preloaded in __main__.
default_model = next(iter(MODEL_CATALOG))

# Collect bundled example images, sorted so the examples appear in a stable order.
if os.path.isdir(EXAMPLE_DIR):
    example_paths = [
        os.path.join(EXAMPLE_DIR, fname)
        for fname in sorted(os.listdir(EXAMPLE_DIR))
        if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
    ]
else:
    example_paths = []

# gr.Interface expects one row per example, with one value per input: [image, model name].
examples_nested = [[path, default_model] for path in example_paths]

interface = gr.Interface(
    fn=predict_anatomy_v3,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=list(MODEL_CATALOG), value=default_model, label="Classifier Model"),
    ],
    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
    title="Lumi's Anatomy Classifier",
    description=DESCRIPTION,
    examples=examples_nested or None,  # None when no example images were found
    allow_flagging="never",
    cache_examples=True,
)
def _preload_default_model() -> None:
    """Best-effort warm-up of the default model; a failure is only reported.

    predict_anatomy_v3 calls load_model on every request, so the load is
    attempted again on first use if this fails.
    """
    try:
        print("Pre-loading default model...")
        model_manager.load_model(default_model)
    except Exception as preload_error:
        print(f"WARNING: Could not pre-load default model. Error: {preload_error}")


if __name__ == "__main__":
    _preload_default_model()
    interface.launch()