# app.py — WOW Edition: Stunning Malaria Detection AI with Advanced Features
# Developer: Oluwafemi Idiakhoa

# ---- Environment fixes ----
# Set before the numeric libraries below are imported.
import os
os.environ["OMP_NUM_THREADS"] = "1"

import glob
import io
import json
import shutil
import tempfile
import zipfile

import numpy as np
import pandas as pd
from PIL import Image
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets

import matplotlib
matplotlib.use("Agg")  # headless backend: figures are only rendered to PNG buffers
import matplotlib.pyplot as plt
import seaborn as sns

from huggingface_hub import hf_hub_download

# ---- Local utils (Grad-CAM) ----
from cam_utils import grad_cam

# ===================== Config =====================
MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): CLASS_NAMES is fixed at two entries while NUM_CLASSES is
# configurable; the diagnosis tab assumes the two agree.
CLASS_NAMES = ["Parasitized", "Uninfected"]

# Model weights resolution
HF_REPO_ID = os.environ.get("HF_REPO_ID", "").strip()
HF_WEIGHTS = os.environ.get("HF_WEIGHTS", "best.pt").strip() if HF_REPO_ID else ""
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt")


def resolve_weights() -> str:
    """Return the path of the checkpoint to load.

    Resolution order: the Hugging Face Hub file (when ``HF_REPO_ID`` is set),
    then ``WEIGHTS_PATH``, then ``checkpoints/best.pt``, ``checkpoints/last.pt``
    and finally any ``checkpoints/*.pt`` in sorted order.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    if HF_REPO_ID:
        try:
            return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_WEIGHTS)
        except Exception as e:
            # Best effort: a failed download falls through to local files.
            print(f"[hub] failed to download {HF_REPO_ID}:{HF_WEIGHTS} → {e}")

    if os.path.exists(WEIGHTS_PATH):
        return WEIGHTS_PATH

    candidates = ["checkpoints/best.pt", "checkpoints/last.pt"]
    candidates += sorted(glob.glob("checkpoints/*.pt"))
    for p in candidates:
        if os.path.exists(p):
            return p
    raise FileNotFoundError("No checkpoint found. Upload a .pt file.")


# ===================== Model =====================
def build_model(name: str, num_classes: int):
    """Build an untrained torchvision classifier with a ``num_classes`` head.

    Args:
        name: ``"efficientnet_b0"`` or ``"resnet50"`` (case-insensitive).
        num_classes: size of the replacement classification layer.

    Raises:
        ValueError: for any other model name.
    """
    name = name.lower()
    if name == "efficientnet_b0":
        m = models.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
        return m
    if name == "resnet50":
        m = models.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported model: {name}")


def _load_checkpoint(model, checkpoint_path: str):
    """Load weights from ``checkpoint_path`` into ``model`` and set eval mode.

    Accepts either a raw state dict or a dict wrapping it under
    ``"state_dict"``. Returns the same model instance.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    state = torch.load(checkpoint_path, map_location=DEVICE)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    model.load_state_dict(state)
    model.eval()
    return model


CHECKPOINT_FILE = resolve_weights()
_model = _load_checkpoint(build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE), CHECKPOINT_FILE)

_pre = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.15)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])


def _fig_to_pil(fig, **savefig_kwargs) -> Image.Image:
    """Render ``fig`` to an RGB PIL image and always close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", **savefig_kwargs)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


# ===================== Enhanced Inference =====================
_RECOMMENDATION_POSITIVE = "\n".join([
    "",
    "### Clinical Recommendation:",
    "- **Immediate microscopic confirmation required**",
    "- **Consult healthcare provider immediately**",
    "- **Begin rapid diagnostic test (RDT)**",
    "- **Consider antimalarial treatment if confirmed**",
    "",
])

_RECOMMENDATION_NEGATIVE = "\n".join([
    "",
    "### Clinical Recommendation:",
    "- **Negative result - low malaria likelihood**",
    "- **Monitor for symptoms development**",
    "- **Consult healthcare provider if symptoms persist**",
    "- **Consider other differential diagnoses**",
    "",
])


def predict_enhanced(image: Image.Image, show_cam: bool):
    """Classify one image and build the diagnosis tab outputs.

    Args:
        image: uploaded PIL image, or ``None`` when nothing was uploaded.
        show_cam: whether to compute the Grad-CAM overlay.

    Returns:
        ``(result_html, overlay, prob_chart, recommendation)``. ``overlay`` is
        ``None`` when Grad-CAM is disabled or fails. With no image the tuple
        is ``("Please upload an image", None, None, None)``.
    """
    if image is None:
        return "Please upload an image", None, None, None

    img = image.convert("RGB")
    x = _pre(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = _model(x).cpu().numpy().squeeze()

    # Softmax, shifted by the max logit for numerical stability.
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()

    pred_idx = int(np.argmax(probs))
    pred_class = CLASS_NAMES[pred_idx]
    confidence = float(probs[pred_idx])

    # Color-coded results
    if pred_idx == 0:  # Parasitized
        color = "#FF4444"
        status = "MALARIA DETECTED"
        emoji = "🦠"
        recommendation = _RECOMMENDATION_POSITIVE
    else:  # Uninfected
        color = "#44FF88"
        status = "NO MALARIA DETECTED"
        emoji = "✅"
        recommendation = _RECOMMENDATION_NEGATIVE

    # Enhanced result HTML. The output component is gr.HTML, so the text is
    # wrapped in minimal markup (bare newlines would collapse into one run).
    result_html = f"""
<div style="border: 2px solid {color}; border-radius: 12px; padding: 20px;">
  <div style="text-align: center;">
    <div style="font-size: 48px;">{emoji}</div>
    <h2 style="color: {color}; margin: 8px 0;">{status}</h2>
    <p>AI-Powered Analysis Complete</p>
  </div>
  <h3>Diagnostic Results</h3>
  <p><strong>Prediction</strong><br>{pred_class}</p>
  <p><strong>Confidence</strong><br>{confidence * 100:.2f}%</p>
  <h4>Class Probabilities</h4>
  <p>
    🦠 Parasitized {probs[0] * 100:.2f}%<br>
    ✅ Uninfected {probs[1] * 100:.2f}%
  </p>
  <h4>⚠️ Medical Disclaimer</h4>
  <p>This is a research tool only and NOT a medical diagnostic device. Results must be confirmed by certified laboratory testing and qualified healthcare professionals. Do not make medical decisions based solely on this AI analysis.</p>
</div>
"""

    # Generate Grad-CAM (best effort: a failure only drops the overlay)
    overlay = None
    if show_cam:
        try:
            cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=DEVICE)
            overlay = Image.fromarray((cam["overlay"] * 255).astype("uint8"))
        except Exception as e:
            print(f"Grad-CAM error: {e}")

    # Create probability chart
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(CLASS_NAMES, probs * 100, color=["#FF4444", "#44FF88"], alpha=0.8)
    ax.set_xlabel("Probability (%)", fontsize=12, fontweight="bold")
    ax.set_title("Prediction Confidence", fontsize=14, fontweight="bold", pad=20)
    ax.set_xlim(0, 100)
    ax.grid(axis="x", alpha=0.3)
    for i, prob in enumerate(probs):
        ax.text(prob * 100 + 2, i, f"{prob * 100:.1f}%", va="center", fontweight="bold")
    plt.tight_layout()
    prob_chart = _fig_to_pil(fig, dpi=150, bbox_inches="tight")

    return result_html, overlay, prob_chart, recommendation


# ===================== ONNX Export =====================
def export_onnx(precision: str):
    """Export the configured model to ONNX in memory.

    Args:
        precision: ``"fp16"`` converts model and dummy input to half
            precision; any other value exports in the loaded precision.

    Returns:
        ``(fname, buf)``: suggested file name and a ``BytesIO`` positioned at
        the start of the serialized model.
    """
    m = _load_checkpoint(build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE), CHECKPOINT_FILE)
    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    if precision == "fp16":
        m = m.half()
        dummy = dummy.half()

    dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    buf = io.BytesIO()
    torch.onnx.export(
        m,
        dummy,
        buf,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        opset_version=17,
        do_constant_folding=True,
    )
    fname = f"model_{MODEL_NAME}_{precision}_{IMAGE_SIZE}.onnx"
    buf.seek(0)
    return fname, buf


# ===================== Validation =====================
def _find_dataset_root(root: str) -> str:
    """Skip wrapper folders so ImageFolder sees the class directories.

    Descends while ``root`` holds exactly one visible entry that is a
    directory which itself contains sub-directories (e.g. ``val/Parasitized``).
    Hidden entries and ``__MACOSX`` are ignored when counting.
    """
    while True:
        entries = [
            e for e in os.listdir(root)
            if not e.startswith(".") and e != "__MACOSX"
        ]
        if len(entries) != 1:
            return root
        inner = os.path.join(root, entries[0])
        if not os.path.isdir(inner):
            return root
        has_subdirs = any(
            os.path.isdir(os.path.join(inner, e)) for e in os.listdir(inner)
        )
        if not has_subdirs:
            # The single directory is itself a class folder.
            return root
        root = inner


def validate(zip_file):
    """Evaluate the loaded model on an uploaded ZIP of class folders.

    Args:
        zip_file: uploaded file (object with ``.name`` or a path string), or
            ``None``.

    Returns:
        ``(report_markdown, confusion_matrix_image)``; the image is ``None``
        when no evaluation could be run.
    """
    if zip_file is None:
        return "Please upload a validation dataset ZIP file.", None

    import sklearn.metrics as sk

    zip_path = getattr(zip_file, "name", zip_file)
    tmp = tempfile.mkdtemp()
    try:
        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(tmp)
        except zipfile.BadZipFile:
            return "The uploaded file is not a valid ZIP archive.", None

        try:
            ds = datasets.ImageFolder(_find_dataset_root(tmp), transform=_pre)
        except (FileNotFoundError, RuntimeError) as e:
            return f"Could not read the validation dataset: {e}", None

        dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
        ys, ps = [], []
        with torch.no_grad():
            for xb, yb in dl:
                preds = _model(xb.to(DEVICE)).argmax(1).cpu().numpy()
                ys.extend(yb.numpy().tolist())
                ps.extend(preds.tolist())
    finally:
        # The extracted images are only needed while iterating the loader.
        shutil.rmtree(tmp, ignore_errors=True)

    # ImageFolder numbers classes by sorted folder name; pin the label set so
    # the report and matrix always have one row per folder.
    labels = list(range(len(ds.classes)))
    rep = sk.classification_report(
        ys, ps, labels=labels, target_names=ds.classes,
        output_dict=True, zero_division=0,
    )
    cm = sk.confusion_matrix(ys, ps, labels=labels)
    acc = sk.accuracy_score(ys, ps)

    # Enhanced confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="RdYlGn_r",
        xticklabels=ds.classes, yticklabels=ds.classes, ax=ax,
        cbar_kws={"label": "Count"}, linewidths=2, linecolor="white",
    )
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight="bold")
    ax.set_ylabel("True Label", fontsize=12, fontweight="bold")
    ax.set_title("Confusion Matrix - Validation Results", fontsize=14, fontweight="bold", pad=20)
    plt.tight_layout()
    cm_img = _fig_to_pil(fig, dpi=160)

    # Format report: one table row per class folder actually present.
    rows = [
        f"| {name} | {rep[name]['precision']:.3f} | {rep[name]['recall']:.3f} "
        f"| {rep[name]['f1-score']:.3f} | {rep[name]['support']:.0f} |"
        for name in ds.classes
    ]
    macro = rep["macro avg"]
    report_md = "\n".join([
        "### Validation Results",
        "",
        f"**Overall Accuracy:** {acc * 100:.2f}%",
        "",
        "#### Per-Class Metrics:",
        "",
        "| Class | Precision | Recall | F1-Score | Support |",
        "|-------|-----------|--------|----------|---------|",
        *rows,
        "",
        f"**Macro Avg:** Precision={macro['precision']:.3f}, "
        f"Recall={macro['recall']:.3f}, F1={macro['f1-score']:.3f}",
        "",
    ])
    return report_md, cm_img


# ===================== Dashboard =====================
METRICS_DEFAULT = "checkpoints/metrics.csv"

# Alternative column names mapped onto the names the dashboard plots.
_COLUMN_ALIASES = {
    "activation_rate": "act_rate",
    "energy_savings": "save_rate",
}


def _normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Add the canonical column for each alias present (in place)."""
    for old_col, new_col in _COLUMN_ALIASES.items():
        if old_col in df.columns and new_col not in df.columns:
            df[new_col] = df[old_col]
    return df


def _epoch_series(df: pd.DataFrame, ycol: str):
    """Return an ``epoch``/``ycol`` frame without NaN rows, or ``None``.

    Falls back to a 1-based row number when the file has no ``epoch`` column.
    """
    if ycol not in df.columns:
        return None
    if "epoch" in df.columns:
        s = df[["epoch", ycol]].dropna()
    else:
        s = pd.DataFrame({
            "epoch": np.arange(1, len(df) + 1),
            ycol: df[ycol].to_numpy(),
        }).dropna()
    return s if len(s) else None


def _style_metric_axes(ax, title: str, ylabel: str) -> None:
    """Apply the shared title/label/grid styling of the metric plots."""
    ax.set_title(title, fontsize=16, fontweight="bold", pad=20)
    ax.set_xlabel("Epoch", fontsize=12, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.3, linestyle="--")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)


def _plot_to_pil(df: pd.DataFrame, ycol: str, title: str, ylabel: str, color="#2196F3"):
    """Plot ``ycol`` against epoch; return a PIL image or ``None`` if no data."""
    s = _epoch_series(df, ycol)
    if s is None:
        return None

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(s["epoch"], s[ycol], marker="o", linewidth=2.5, markersize=8, color=color)
    ax.fill_between(s["epoch"], s[ycol], alpha=0.3, color=color)
    _style_metric_axes(ax, title, ylabel)
    plt.tight_layout()
    return _fig_to_pil(fig, dpi=160, bbox_inches="tight", facecolor="white")


def load_metrics(path: str):
    """Load a training-metrics CSV and build the dashboard outputs.

    Returns:
        ``(summary_md, loss_img, acc_img, act_img, save_img, thr_img,
        csv_path)``. On a missing or unreadable file the first element is an
        error message and the rest are ``None``. ``csv_path`` is a temporary
        copy of the normalized table for the download button.
    """
    if not os.path.exists(path):
        return ("Metrics file not found. Upload your training metrics CSV.",
                None, None, None, None, None, None)
    try:
        df = pd.read_csv(path)
    except Exception as e:
        return f"Error reading CSV: {e}", None, None, None, None, None, None

    # Normalize column names for compatibility
    _normalize_columns(df)

    # Try markdown table, fallback to plain text if tabulate is missing
    try:
        last_5_table = df.tail(5).to_markdown(index=False)
    except ImportError:
        last_5_table = "```\n" + df.tail(5).to_string(index=False) + "\n```"

    # Summary statistics: only report what the file actually contains.
    parts = ["### Training Summary", f"**Total Epochs:** {len(df)}"]
    if "val_acc" in df.columns and df["val_acc"].notna().any():
        val_acc = df["val_acc"]
        parts.append(
            f"**Best Validation Accuracy:** {val_acc.max() * 100:.2f}% "
            f"(Epoch {val_acc.idxmax() + 1})"
        )
    if "train_loss" in df.columns and len(df):
        parts.append(f"**Final Training Loss:** {df['train_loss'].iloc[-1]:.4f}")
    if "save_rate" in df.columns:
        parts.append(f"**Average Energy Savings:** {df['save_rate'].mean() * 100:.1f}%")
    parts.append("#### Last 5 Epochs:")
    parts.append(last_5_table)
    summary = "\n\n".join(parts) + "\n"

    fig_loss = _plot_to_pil(df, "train_loss", "Training Loss Over Time", "Loss", color="#FF5722")
    fig_acc = _plot_to_pil(df, "val_acc", "Validation Accuracy Over Time", "Accuracy", color="#4CAF50")
    fig_act = _plot_to_pil(df, "act_rate", "Activation Rate (AST)", "Activation Rate", color="#2196F3")
    fig_save = _plot_to_pil(df, "save_rate", "Energy Savings (AST)", "Savings Fraction", color="#9C27B0")
    fig_thr = _plot_to_pil(df, "threshold", "Activation Threshold (AST)", "Threshold", color="#FF9800")

    # The download button needs a file path, not raw bytes.
    csv_path = os.path.join(tempfile.mkdtemp(), "metrics.csv")
    df.to_csv(csv_path, index=False)
    return summary, fig_loss, fig_acc, fig_act, fig_save, fig_thr, csv_path


_RUN_COLORS = ["#2196F3", "#FF5722", "#4CAF50", "#9C27B0", "#FF9800", "#00BCD4"]


def _overlay_plot(runs, ycol: str, title: str, ylabel: str):
    """Overlay ``ycol`` of every run on one chart; ``None`` if no run has it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    found = False
    for i, (name, df) in enumerate(runs):
        s = _epoch_series(df, ycol)
        if s is None:
            continue
        ax.plot(
            s["epoch"], s[ycol], marker="o", linewidth=2, markersize=6,
            label=name, color=_RUN_COLORS[i % len(_RUN_COLORS)], alpha=0.8,
        )
        found = True
    if not found:
        plt.close(fig)
        return None

    _style_metric_axes(ax, title, ylabel)
    ax.legend(fontsize=10, loc="best")
    plt.tight_layout()
    return _fig_to_pil(fig, dpi=160, bbox_inches="tight", facecolor="white")


def compare_runs(files):
    """Overlay the metrics of several uploaded CSV files.

    Returns:
        ``(message, loss_img, acc_img, act_img, save_img, thr_img)``; images
        are ``None`` when no file has the column or a file cannot be read.
    """
    if not files:
        return ("Upload 2 or more metrics.csv files to compare training runs.",
                None, None, None, None, None)

    runs = []
    for f in files:
        path = getattr(f, "name", f)
        try:
            df = _normalize_columns(pd.read_csv(path))
        except Exception as e:
            return f"Error reading {path}: {e}", None, None, None, None, None
        runs.append((os.path.basename(path), df))

    msg = f"Successfully compared {len(runs)} training runs."
    p_loss = _overlay_plot(runs, "train_loss", "Training Loss Comparison", "Loss")
    p_acc = _overlay_plot(runs, "val_acc", "Validation Accuracy Comparison", "Accuracy")
    p_act = _overlay_plot(runs, "act_rate", "Activation Rate Comparison", "Activation Rate")
    p_save = _overlay_plot(runs, "save_rate", "Energy Savings Comparison", "Savings Fraction")
    p_thr = _overlay_plot(runs, "threshold", "Threshold Comparison", "Threshold")
    return msg, p_loss, p_acc, p_act, p_save, p_thr


# ===================== Custom CSS =====================
custom_css = """
#header {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 30px;
    border-radius: 15px;
    margin-bottom: 30px;
    text-align: center;
    color: white;
}
#header h1 {
    margin: 0;
    font-size: 42px;
    font-weight: 800;
}
#header p {
    margin: 10px 0 0 0;
    font-size: 18px;
    opacity: 0.95;
}
.badge {
    display: inline-block;
    padding: 8px 16px;
    margin: 5px;
    background: rgba(255,255,255,0.2);
    border-radius: 20px;
    font-weight: 600;
}
"""

# ===================== Gradio UI =====================
with gr.Blocks(title="Malaria Detection AI - Advanced Diagnostics",
               css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header (uses the #header rules from custom_css)
    gr.HTML("""
    <div id="header">
        <h1>Malaria Detection AI</h1>
        <p>Advanced Deep Learning for Global Health</p>
    </div>
    """)

    gr.Markdown(f"""
    ### System Information
    **Model:** `{MODEL_NAME}` | **Weights:** `{os.path.basename(CHECKPOINT_FILE)}` | **Device:** `{DEVICE}` | **Classes:** {NUM_CLASSES}
    """)

    # -------- Main Inference Tab --------
    with gr.Tab("🔍 Diagnosis"):
        gr.Markdown("""
        ### Upload Blood Smear Image for Analysis
        Upload a microscopy image of a blood cell to detect malaria parasites using AI.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload Blood Cell Image", height=400)
                show_cam = gr.Checkbox(value=True, label="Show Grad-CAM Visualization (Explainable AI)")
                btn_pred = gr.Button("🔬 Analyze for Malaria", variant="primary", size="lg")
            with gr.Column(scale=1):
                result_out = gr.HTML(label="Diagnostic Results")

        with gr.Row():
            with gr.Column():
                cam_out = gr.Image(type="pil", label="Grad-CAM Heat Map (Where AI Looks)")
            with gr.Column():
                chart_out = gr.Image(type="pil", label="Confidence Distribution")

        recommendation_out = gr.Markdown(label="Clinical Recommendations")

        btn_pred.click(
            fn=predict_enhanced,
            inputs=[img_in, show_cam],
            outputs=[result_out, cam_out, chart_out, recommendation_out],
        )

    # -------- Validation Tab --------
    with gr.Tab("✅ Model Validation"):
        gr.Markdown("""
        ### Validate Model Performance
        Upload a ZIP file containing a validation dataset (with Parasitized/ and Uninfected/ folders).
        """)

        val_zip = gr.File(label="Upload Validation Dataset (.zip)", file_types=[".zip"])
        btn_eval = gr.Button("📊 Run Validation", variant="primary")

        with gr.Row():
            with gr.Column():
                rep_out = gr.Markdown(label="Classification Report")
            with gr.Column():
                cm_img = gr.Image(type="pil", label="Confusion Matrix")

        btn_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])

    # -------- Dashboard Tab --------
    with gr.Tab("📈 Training Dashboard"):
        gr.Markdown("""
        ### Visualize Training Metrics
        View training progress, validation accuracy, and energy savings from Adaptive Sparse Training (AST).

        **Metrics are automatically loaded from checkpoints/metrics.csv. Upload a different file if needed.**
        """)

        # Load initial metrics from checkpoints
        (initial_summary, initial_loss, initial_acc, initial_act,
         initial_save, initial_thr, initial_csv) = load_metrics(METRICS_DEFAULT)

        # Optional: File upload to override
        with gr.Row():
            single_metrics_file = gr.File(label="Upload different metrics.csv (optional)", file_types=[".csv"])
            btn_upload = gr.Button("📊 Load Metrics", variant="secondary")

        summary_md = gr.Markdown(value=initial_summary)

        with gr.Row():
            plot_loss = gr.Image(label="Training Loss", value=initial_loss)
            plot_acc = gr.Image(label="Validation Accuracy", value=initial_acc)
        with gr.Row():
            plot_act = gr.Image(label="Activation Rate (AST)", value=initial_act)
            plot_save = gr.Image(label="Energy Savings (AST)", value=initial_save)
        plot_thr = gr.Image(label="Activation Threshold (AST)", value=initial_thr)

        # Pre-populated with the default metrics file (None when it is missing)
        dl_btn = gr.DownloadButton(label="⬇️ Download Metrics CSV", value=initial_csv)

        # Connect upload button for custom metrics
        btn_upload.click(
            fn=lambda f: load_metrics(getattr(f, "name", f) if f else ""),
            inputs=[single_metrics_file],
            outputs=[summary_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn],
        )

        gr.Markdown("---")
        gr.Markdown("### Compare Multiple Training Runs")

        mult = gr.Files(label="Upload Multiple metrics.csv Files", file_types=[".csv"])
        cmp_msg = gr.Markdown()

        with gr.Row():
            p_loss = gr.Image(label="Loss Comparison")
            p_acc = gr.Image(label="Accuracy Comparison")
        with gr.Row():
            p_act = gr.Image(label="Activation Comparison")
            p_save = gr.Image(label="Savings Comparison")
        p_thr = gr.Image(label="Threshold Comparison")

        mult.upload(
            fn=compare_runs,
            inputs=[mult],
            outputs=[cmp_msg, p_loss, p_acc, p_act, p_save, p_thr],
        )

    # -------- Export Tab --------
    with gr.Tab("📦 Model Export"):
        gr.Markdown("""
        ### Export Model to ONNX Format
        Convert the PyTorch model to ONNX format for production deployment and cross-platform compatibility.
        """)

        with gr.Row():
            with gr.Column():
                onnx_precision = gr.Radio(
                    choices=["fp32", "fp16"],
                    value="fp32",
                    label="Precision",
                    info="FP16 for faster inference, FP32 for maximum accuracy",
                )
                btn_onnx = gr.Button("🚀 Export to ONNX", variant="primary")
            with gr.Column():
                onnx_file = gr.File(label="Download ONNX Model", interactive=False)

        def onnx_wrap(prec):
            """Export the model and return the path of the written .onnx file."""
            fname, fobj = export_onnx(prec)
            # gr.File serves files by path, so persist the in-memory export.
            out_path = os.path.join(tempfile.mkdtemp(), fname)
            with open(out_path, "wb") as fh:
                fh.write(fobj.getvalue())
            return out_path

        btn_onnx.click(fn=onnx_wrap, inputs=[onnx_precision], outputs=[onnx_file])

    # -------- About Tab --------
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About This System

        ### Technology Stack
        - **Deep Learning Framework:** PyTorch
        - **Model Architecture:** EfficientNet-B0
        - **Training Method:** Adaptive Sparse Training (AST) with Sundew Algorithm
        - **Explainable AI:** Grad-CAM (Gradient-weighted Class Activation Mapping)
        - **Dataset:** NIH Malaria Cell Images (27,558 samples)

        ### Performance Metrics
        - **Validation Accuracy:** 93.94% (final epoch), 94.63% (best epoch)
        - **Energy Savings:** 88% reduction in training cost vs. traditional methods
        - **Inference Speed:** <1 second per image
        - **Model Size:** ~16MB
        - **Training:** 30 epochs on NIH Malaria Dataset

        ### Key Features
        1. **Real-time Diagnosis:** Upload blood smear images for instant analysis
        2. **Explainable AI:** Grad-CAM shows exactly where the model detects parasites
        3. **Energy Efficient:** Trained using Adaptive Sparse Training with Sundew algorithm for 88% energy savings
        4. **Clinical Recommendations:** Actionable advice based on predictions
        5. **Model Validation:** Built-in tools for performance evaluation
        6. **ONNX Export:** Deploy anywhere with standard model format

        ### Use Cases
        - **Research:** Academic studies on malaria detection
        - **Education:** Teaching AI applications in healthcare
        - **Triage:** Rapid pre-screening in resource-limited settings
        - **Model Comparison:** Benchmark against other approaches

        ### Important Disclaimers
        ⚠️ **This is a research prototype and NOT a medical device.**

        - Results must be confirmed by certified laboratory testing
        - Do not use for clinical diagnosis without professional validation
        - Always consult qualified healthcare providers
        - This tool is for research and educational purposes only

        ### Developer
        **Oluwafemi Idiakhoa**

        ### Citation
        If you use this system in your research, please cite:

        ```
        @software{malaria_ast_detection,
          author = {Idiakhoa, Oluwafemi},
          title = {Malaria Detection using Adaptive Sparse Training},
          year = {2025},
          url = {https://github.com/oluwafemidiakhoa/Malaria}
        }
        ```

        ### Resources
        - [GitHub Repository](https://github.com/oluwafemidiakhoa/Malaria)
        - AST Library (PyPI): pypi.org/project/adaptive-sparse-training
        - [NIH Malaria Dataset](https://lhncbc.nlm.nih.gov/LHC-research/LHC-projects/image-processing/malaria-datasheet.html)

        ---

        Built with EfficientNet-B0 + Adaptive Sparse Training | Powered by PyTorch & Gradio
        """)

    # Footer
    gr.Markdown("""
    ---

    Malaria Detection AI | Advanced Deep Learning for Global Health

    Developer: Oluwafemi Idiakhoa | [GitHub](https://github.com/oluwafemidiakhoa/Malaria)

    This is a research tool. Not for clinical use.
    """)

if __name__ == "__main__":
    demo.launch()