# app.py — Gradio UI with Inference, Validation, and 📈 Metrics Dashboard
import os
import io
import json
import glob

os.environ.setdefault("OMP_NUM_THREADS", "1")  # quiet libgomp in HF Spaces

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
import matplotlib
matplotlib.use("Agg")  # headless backend: we only rasterize to PNG, never show()
import matplotlib.pyplot as plt
import seaborn as sns

# local utils
from cam_utils import grad_cam

# ---------------- Config ----------------
MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))

# prefer env var; fall back to best.pt -> last.pt -> any .pt in /checkpoints
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt")
if not os.path.exists(WEIGHTS_PATH):
    candidates = ["checkpoints/best.pt", "checkpoints/last.pt"] + sorted(glob.glob("checkpoints/*.pt"))
    WEIGHTS_PATH = next((p for p in candidates if os.path.exists(p)), "")

CLASS_NAMES = ["Parasitized", "Uninfected"] if NUM_CLASSES == 2 else [str(i) for i in range(NUM_CLASSES)]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# ---------------- Model ----------------
def build_model(name: str, num_classes: int) -> nn.Module:
    """Build an untrained backbone with its classifier head resized to ``num_classes``.

    Args:
        name: Backbone identifier, ``"efficientnet_b0"`` or ``"resnet50"``
            (case-insensitive).
        num_classes: Output dimension of the replaced classification head.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """
    name = name.lower()
    if name == "efficientnet_b0":
        m = models.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
        return m
    if name == "resnet50":
        m = models.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported model_name: {name}")


if not WEIGHTS_PATH:
    raise FileNotFoundError(
        "No checkpoint found. Upload a .pt to /checkpoints and/or set WEIGHTS_PATH variable."
    )

_model = build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE)
_state = torch.load(WEIGHTS_PATH, map_location=DEVICE)
# allow either a plain state_dict or a training checkpoint {"state_dict": ...}
try:
    _model.load_state_dict(_state)
except Exception:
    _model.load_state_dict(_state["state_dict"])
_model.eval()

# NOTE(review): there is no transforms.Normalize() here — this only matches the
# trained model if the Colab training pipeline also skipped ImageNet
# normalization. Confirm against the training code.
_pre = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.15)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])


def _fig_to_png_path(fig) -> str:
    """Rasterize a Matplotlib figure to a temp PNG and return its path.

    Closes the figure afterwards — the original code never called
    ``plt.close`` anywhere, leaking one figure per request in a
    long-running Space.
    """
    import tempfile
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    fig.savefig(path, format="png", dpi=160, bbox_inches="tight")
    plt.close(fig)
    return path


# ---------------- Inference ----------------
def predict(image: Image.Image, show_cam: bool):
    """Classify one blood-smear image; optionally compute a Grad-CAM overlay.

    Args:
        image: PIL image from the Gradio Image component (may be None).
        show_cam: Whether to also run Grad-CAM.

    Returns:
        Tuple of (``{class_name: probability}`` dict for gr.Label,
        PIL overlay image or None).
    """
    if image is None:
        return {"label": "", "confidences": []}, None
    img = image.convert("RGB")
    x = _pre(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = _model(x).cpu().numpy().squeeze()
    # numerically stable softmax over the single sample's logits
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()
    # Gradio Label expects a dict like {"class": prob}
    label_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
    overlay = None
    if show_cam:
        # grad_cam needs gradients, so it runs outside the no_grad block above
        cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=DEVICE)
        overlay = Image.fromarray((cam["overlay"] * 255).astype("uint8"))
    return label_dict, overlay


# ---------------- Validation (optional) ----------------
def validate(zip_or_folder):
    """Evaluate the model on a zipped ImageFolder-style validation set.

    Args:
        zip_or_folder: Gradio File upload (a .zip whose class subfolders
            match the training classes), or None.

    Returns:
        Tuple of (classification_report as a JSON string,
        path to a confusion-matrix PNG) — the PNG path matches the
        ``gr.Image(type="filepath")`` output component.
    """
    import tempfile
    import zipfile
    if zip_or_folder is None:
        return "Upload a .zip of your validation set.", None
    tmp = tempfile.mkdtemp()
    # accept zip only (Spaces File component gives .name)
    with zipfile.ZipFile(zip_or_folder.name, 'r') as zf:
        zf.extractall(tmp)
    root = tmp
    # ROBUSTNESS FIX: many zips wrap everything in one top-level folder; descend
    # into it so ImageFolder sees the class subfolders directly instead of
    # treating the wrapper as a single class.
    entries = [os.path.join(root, e) for e in os.listdir(root)]
    if len(entries) == 1 and os.path.isdir(entries[0]):
        root = entries[0]
    ds = datasets.ImageFolder(root, transform=_pre)
    dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
    ys, ps = [], []
    with torch.no_grad():
        for xb, yb in dl:
            preds = _model(xb.to(DEVICE)).argmax(1).cpu().numpy()
            ys.extend(yb.numpy())
            ps.extend(preds)
    import sklearn.metrics as sk
    rep = sk.classification_report(ys, ps, target_names=ds.classes, output_dict=True)
    cm = sk.confusion_matrix(ys, ps)
    # plot CM
    fig, ax = plt.subplots(figsize=(4.5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=ds.classes, yticklabels=ds.classes, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    # BUG FIX: the output component is gr.Image(type="filepath"); the original
    # returned an io.BytesIO, which is not a valid value for it.
    return json.dumps(rep, indent=2), _fig_to_png_path(fig)


# ---------------- Metrics Dashboard helpers ----------------
METRICS_DEFAULT = "checkpoints/metrics.csv"

# Empty tail for the 6 non-text outputs (5 plots + download path).
_EMPTY_OUTPUTS = (None, None, None, None, None, None)


def _plot_line(df: pd.DataFrame, ycol: str, title: str, ylabel: str):
    """Plot one metric column against ``epoch``.

    Returns the path of a rendered PNG, or None when the column is missing
    or has no non-NaN rows.
    """
    if ycol not in df.columns:
        return None
    s = df[["epoch", ycol]].dropna()
    if len(s) == 0:
        return None
    fig, ax = plt.subplots(figsize=(5.5, 3.2))
    ax.plot(s["epoch"], s[ycol], marker="o")
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    # BUG FIX: return a PNG path (valid for gr.Image and closes the figure)
    # instead of a leaked-figure BytesIO.
    return _fig_to_png_path(fig)


def load_metrics(path: str):
    """Load a training-metrics CSV and render the dashboard outputs.

    Returns:
        Tuple of (markdown table of the last rows, five plot PNG paths
        (each possibly None), CSV path for the DownloadButton).
    """
    if not os.path.exists(path):
        return ("No metrics file found.",) + _EMPTY_OUTPUTS
    try:
        df = pd.read_csv(path)
    except Exception as e:
        return (f"Failed to read CSV: {e}",) + _EMPTY_OUTPUTS
    # last rows table (string)
    tail = df.tail(10).to_markdown(index=False)
    fig_loss = _plot_line(df, "train_loss", "Training Loss", "loss")
    fig_acc = _plot_line(df, "val_acc", "Validation Accuracy", "acc")
    fig_act = _plot_line(df, "act_rate", "Activation Rate", "fraction")
    fig_save = _plot_line(df, "save_rate", "Energy Savings", "fraction")
    fig_thr = _plot_line(df, "threshold", "Activation Threshold", "threshold")
    # BUG FIX: gr.DownloadButton expects a file path, not raw bytes — hand it
    # the CSV that is already on disk.
    return tail, fig_loss, fig_acc, fig_act, fig_save, fig_thr, path


def upload_metrics(file):
    """Copy an uploaded metrics.csv into checkpoints/ and render it.

    Allows uploading a metrics.csv from Colab to the Space.
    """
    import shutil
    if file is None:
        return ("No file uploaded.",) + _EMPTY_OUTPUTS
    os.makedirs("checkpoints", exist_ok=True)
    # BUG FIX: Gradio's File component yields a tempfile wrapper exposing a
    # .name path, not a readable stream — the original file.read() would raise.
    shutil.copy(file.name, METRICS_DEFAULT)
    # reload
    return load_metrics(METRICS_DEFAULT)


# ---------------- Gradio UI ----------------
with gr.Blocks(title="Malaria Diagnostic Assistant") as demo:
    gr.Markdown("# 🩸 Malaria Diagnostic Assistant")
    gr.Markdown("Prototype — energy-efficient triage with human-in-the-loop (Adaptive Sparse Training)")

    with gr.Tab("🔍 Inference"):
        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload blood smear image")
                show_cam = gr.Checkbox(value=True, label="Show Grad-CAM")
                btn_pred = gr.Button("Predict", variant="primary")
            with gr.Column(scale=1):
                label_out = gr.Label(num_top_classes=2, label="Prediction & Probabilities")
                cam_out = gr.Image(type="pil", label="Grad-CAM overlay")
        btn_pred.click(fn=predict, inputs=[img_in, show_cam], outputs=[label_out, cam_out])

    with gr.Tab("✅ Validation (optional)"):
        gr.Markdown("Upload a **.zip** containing a folder with class subfolders "
                    "(e.g., `Parasitized/`, `Uninfected/`).")
        val_zip = gr.File(label="Validation ZIP", file_types=[".zip"])
        btn_eval = gr.Button("Compute report + confusion matrix")
        rep_out = gr.Textbox(label="classification_report (JSON)")
        cm_img = gr.Image(type="filepath", label="Confusion Matrix")
        btn_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])

    with gr.Tab("📈 Dashboard"):
        gr.Markdown("Visualize training logs from `checkpoints/metrics.csv` "
                    "(written each epoch by your Colab training).")
        with gr.Row():
            metrics_path = gr.Textbox(value=METRICS_DEFAULT, label="Metrics CSV path")
            btn_load = gr.Button("Refresh")
        tail_md = gr.Markdown(value="Upload metrics or click refresh.")
        plot_loss = gr.Image(label="Training Loss")
        plot_acc = gr.Image(label="Validation Accuracy")
        plot_act = gr.Image(label="Activation Rate")
        plot_save = gr.Image(label="Energy Savings")
        plot_thr = gr.Image(label="Activation Threshold")
        # BUG FIX: DownloadButton has no `file_name` parameter — the served
        # filename comes from the path passed as its value.
        dl_btn = gr.DownloadButton(label="⬇️ Download metrics.csv", value=None)

        # wire up refresh
        btn_load.click(fn=load_metrics, inputs=[metrics_path],
                       outputs=[tail_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn])

        gr.Markdown("---")
        gr.Markdown("Or upload a `metrics.csv` here to preview:")
        up_file = gr.File(label="Upload metrics.csv", file_types=[".csv"])
        up_file.upload(fn=upload_metrics, inputs=[up_file],
                       outputs=[tail_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn])

        # optional auto-refresh every 10s
        auto = gr.Checkbox(label="Auto-refresh every 10s", value=False)

        def maybe_refresh(path, enabled):
            """Reload metrics on each tick, but only when the checkbox is on."""
            if not enabled:
                # return no update if disabled
                return tuple(gr.update() for _ in range(7))
            return load_metrics(path)

        # BUG FIX: Timer's `active` expects a bool, not a Checkbox component;
        # the enabled/disabled gate lives in maybe_refresh instead.
        refresher = gr.Timer(10.0)
        refresher.tick(fn=maybe_refresh, inputs=[metrics_path, auto],
                       outputs=[tail_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn])

    gr.Markdown("---")
    gr.Markdown(f"**Weights**: `{WEIGHTS_PATH}` • **Backbone**: `{MODEL_NAME}` • **Device**: `{DEVICE}`\n\n"
                "Built with EfficientNet-B0 + Adaptive Sparse Training — not a diagnostic device.")

if __name__ == "__main__":
    demo.launch()