import base64
import io
import os

import cv2
import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from PIL import Image
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
DEVICE = "cpu"
IMG_SIZE = 256
CONFIDENCE_THRESHOLD = 0.298


# --- DATA MODELS FOR API ---
class InferenceRequest(BaseModel):
    image: str           # base64 data URL of the input image
    scribble_mask: str   # base64 data URL of the user's scribble mask


class InferenceResponse(BaseModel):
    predicted_mask: str  # base64 data URL of the predicted binary mask


# --- INITIALIZE FASTAPI APP ---
app = FastAPI()


# --- LOAD MODEL ON STARTUP ---
def load_model():
    print(f"Loading model '{HF_MODEL_REPO_ID}'...")
    try:
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
    except Exception as e:
        # Fallback for local testing if the model file is in the same directory
        if os.path.exists("best_model.pth"):
            print("Could not download from Hub, using local 'best_model.pth'.")
            model_path = "best_model.pth"
        else:
            raise e

    # 4 input channels: RGB image + 1 scribble channel; 1 output class (the mask)
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,
        classes=1,
    )
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model


model = load_model()


# --- HELPER FUNCTIONS ---
def base64_to_cv2_rgba(base64_string: str):
    """Decode a 'data:image/...;base64,...' data URL into a BGRA OpenCV image."""
    header, encoded = base64_string.split(",", 1)
    img_data = base64.b64decode(encoded)
    # Force RGBA so the cvtColor below is valid even for RGB or grayscale inputs
    pil_image = Image.open(io.BytesIO(img_data)).convert("RGBA")
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)


def cv2_to_base64(image: np.ndarray):
    """Encode an OpenCV image as a PNG data URL."""
    _, buffer = cv2.imencode('.png', image)
    png_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{png_as_text}"


# --- API ENDPOINTS ---
@app.get("/")
def read_root():
    return FileResponse('index.html')


@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    image_cv = base64_to_cv2_rgba(request.image)
    scribble_cv = base64_to_cv2_rgba(request.scribble_mask)

    # Collapse the scribble to a single channel
    if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)

    h, w, _ = image_cv.shape

    # Resize to the model's input size; nearest-neighbor keeps the scribble binary
    image_resized = cv2.resize(
        cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB),
        (IMG_SIZE, IMG_SIZE),
        interpolation=cv2.INTER_AREA,
    )
    scribble_resized = cv2.resize(
        scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )

    # Build the 4-channel input tensor: shape (1, 4, IMG_SIZE, IMG_SIZE), values in [0, 1]
    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.sigmoid(output)
        binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()

    # Resize the mask back to the original resolution and encode it as a PNG data URL
    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)

    result_base64 = cv2_to_base64(output_mask_uint8)
    return InferenceResponse(predicted_mask=result_base64)
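
# --- EXAMPLE CLIENT (illustrative sketch, not part of the server) ---
# A minimal way to exercise the /predict endpoint from another process,
# assuming the server is started with `uvicorn app:app` on port 8000 (the
# module name "app" and the port are assumptions) and that "leaf.png" and
# "scribble.png" are hypothetical local files. Requires the `requests` package.
#
#   import base64
#   import requests
#
#   def to_data_url(path: str) -> str:
#       # Encode a local PNG as the data URL format the endpoint expects
#       with open(path, "rb") as f:
#           return "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")
#
#   payload = {
#       "image": to_data_url("leaf.png"),
#       "scribble_mask": to_data_url("scribble.png"),
#   }
#   resp = requests.post("http://localhost:8000/predict", json=payload)
#   resp.raise_for_status()
#
#   # Strip the "data:image/png;base64," header and save the returned mask
#   mask_b64 = resp.json()["predicted_mask"].split(",", 1)[1]
#   with open("predicted_mask.png", "wb") as f:
#       f.write(base64.b64decode(mask_b64))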