# app.py
import os, io, traceback
from typing import Optional, List, Tuple

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError, ImageFile
from torchvision import transforms as T

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Route every cache (HF hub, open_clip, torch) under a single writable root.
CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")
os.environ["XDG_CACHE_HOME"] = CACHE_ROOT
os.environ["HF_HOME"] = os.path.join(CACHE_ROOT, "hf")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.environ["HF_HOME"]
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]
os.environ["OPENCLIP_CACHE_DIR"] = os.path.join(CACHE_ROOT, "open_clip")
os.environ["TORCH_HOME"] = os.path.join(CACHE_ROOT, "torch")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["OPENCLIP_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)

import open_clip  # import after the cache env vars are set

# ===== basic limits =====
NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
torch.set_num_threads(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
try:
    torch.set_num_interop_threads(1)
except Exception:
    pass

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
    torch.set_float32_matmul_precision("high")

# ===== embedding paths =====
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")

# ===== PE bigG model =====
MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
PRETRAINED = None

app = FastAPI(title="OpenCLIP PE bigG Vehicle API")

# ===== model / preprocess =====
_ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# some open_clip versions return (model, preprocess_train, preprocess_val)
if isinstance(_ret, tuple) and len(_ret) == 3:
    clip_model, _preprocess_train, preprocess = _ret
else:
    clip_model, preprocess = _ret
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Recover normalization stats and input size from the bundled preprocess pipeline.
normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
if isinstance(SIZE, (tuple, list)):
    SIZE = max(SIZE)
if SIZE is None:
    SIZE = 448  # PE bigG uses 448; fallback

transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])

# ===== image utils (unchanged: letterbox + BICUBIC) =====
def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
    """Scale to fit inside a size x size square, then pad with black."""
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    if w == 0 or h == 0:
        raise UnidentifiedImageError("invalid image")
    scale = size / max(w, h)
    nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
    img_resized = img.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(img_resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas

# ===== load embeddings (unchanged) =====
def _ensure_label_list(x):
    if isinstance(x, (list, tuple)):
        return list(x)
    if hasattr(x, "tolist"):
        return [str(s) for s in x.tolist()]
    return [str(s) for s in x]

def _load_embeddings(path: str):
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)  # L2-normalize for cosine similarity
    return labels, embeds
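# NOTE (assumed offline format): each .pt file above is expected to be a dict of the
# form {"labels": List[str], "embeddings": FloatTensor[N, D]}, where the embeddings
# were produced by the text tower of the same MODEL_NAME, e.g. something like
#   torch.save({"labels": labels, "embeddings": text_feats}, "text_embeddings_bigg.pt")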
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)

# check the embedding dimension (PE bigG keeps 1280)
with torch.inference_mode():
    dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
    img_dim = clip_model.encode_image(dummy).shape[-1]
if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
    raise RuntimeError(
        f"dimension mismatch: image={img_dim}, models={model_embeddings.shape[1]}, "
        f"versions={version_embeddings.shape[1]}. Recompute the embeddings with {MODEL_NAME}."
    )

_versions_cache: dict[str, Tuple[List[str], Optional[torch.Tensor]]] = {}

def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
    hit = _versions_cache.get(modelo_full)
    if hit is not None:
        return hit
    idxs = [i for i, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
    if not idxs:
        _versions_cache[modelo_full] = ([], None)
        return _versions_cache[modelo_full]
    labels_sub = [version_labels[i] for i in idxs]
    embeds_sub = version_embeddings[idxs]  # copy of those rows
    _versions_cache[modelo_full] = (labels_sub, embeds_sub)
    return _versions_cache[modelo_full]

# ===== inference (no logic/precision changes) =====
@torch.inference_mode()
def _encode_pil(img: Image.Image) -> torch.Tensor:
    img = resize_letterbox(img, SIZE)
    tensor = transform(img).unsqueeze(0).to(device=DEVICE)
    if DEVICE == "cuda":
        tensor = tensor.to(dtype=DTYPE)
    feats = clip_model.encode_image(tensor)
    return feats / feats.norm(dim=-1, keepdim=True)

def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
    sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
    vals, idxs = torch.topk(sim, k=k)
    # softmax over the top-k scores only; with k=1 this is always 1.0 (confidence 100.0)
    conf = torch.softmax(vals, dim=0)
    return [{"label": text_labels[int(i)], "confidence": round(float(c) * 100.0, 2)}
            for i, c in zip(idxs, conf)]

def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
    # back_bytes is accepted but not used in the current pipeline
    if not front_bytes or len(front_bytes) < 128:
        raise UnidentifiedImageError("invalid image")
    img_front = Image.open(io.BytesIO(front_bytes))
    img_feat = _encode_pil(img_front)

    # step 1: model
    top_model = _topk_cosine(model_embeddings, model_labels, img_feat, k=1)[0]
    modelo_full = top_model["label"]
    partes = modelo_full.split(" ", 1)
    marca = partes[0]
    modelo = partes[1] if len(partes) == 2 else ""

    # step 2: versions, with cache
    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
    if not labels_sub:
        return {"brand": marca.upper(), "model": modelo.title(), "version": ""}

    # step 3: version
    top_ver = _topk_cosine(embeds_sub, labels_sub, img_feat, k=1)[0]
    raw = top_ver["label"]
    prefix = modelo_full + " "
    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]
    if top_ver["confidence"] < 30.0:  # never triggers while k=1 (see _topk_cosine)
        ver = ""
    return {"brand": marca.upper(), "model": modelo.title(), "version": ver.title() if ver else ""}
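# Illustrative walk-through (label strings are hypothetical): if the closest model
# label is "Toyota Corolla", step 2 narrows the search to version labels starting
# with "Toyota Corolla" (e.g. "Toyota Corolla Xse 2.0"); step 3 strips that prefix
# and keeps the first remaining token, giving
#   {"brand": "TOYOTA", "model": "Corolla", "version": "Xse"}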
# ===== endpoints =====
@app.get("/")
def root():
    return {"status": "ok", "device": DEVICE, "model": MODEL_NAME,
            "img_dim": int(model_embeddings.shape[1]), "threads": NUM_THREADS}

@app.post("/predict/")
async def predict(front: UploadFile = File(None), back: Optional[UploadFile] = File(None)):
    try:
        if front is None:
            return JSONResponse(content={"code": 400, "error": "missing files: 'front' is required"},
                                status_code=200)
        front_bytes = await front.read()
        back_bytes = await back.read() if back is not None else None
        vehicle = process_image_bytes(front_bytes, back_bytes)
        return JSONResponse(content={"code": 200, "data": {"vehicle": vehicle}}, status_code=200)
    except (UnidentifiedImageError, OSError, RuntimeError, ValueError) as e:
        return JSONResponse(content={"code": 404, "data": {}, "error": str(e)}, status_code=200)
    except Exception:
        traceback.print_exc()
        return JSONResponse(content={"code": 404, "data": {}, "error": "internal"}, status_code=200)
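# --- Usage sketch (file name, host, and port are illustrative) ---
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -F "front=@car.jpg" http://localhost:8000/predict/
# Note that the API always answers HTTP 200; success vs. failure is signalled by
# the "code" field of the JSON body (200 on success, 400/404 otherwise).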