# app.py — WOW Edition: Stunning Malaria Detection AI with Advanced Features
# Developer: Oluwafemi Idiakhoa
# ---- Environment fixes ----
import os
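# Cap OpenMP to a single thread; avoids CPU oversubscription on small shared
# hosts (e.g., Hugging Face Spaces) where PyTorch's native libs spawn workers.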
os.environ["OMP_NUM_THREADS"] = "1"
import io, glob, tempfile
import numpy as np
import pandas as pd
from PIL import Image
import gradio as gr
import torch, torch.nn as nn
from torchvision import models, transforms, datasets
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from huggingface_hub import hf_hub_download
# ---- Local utils (Grad-CAM) ----
from cam_utils import grad_cam
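# grad_cam (from cam_utils) is used below as if it returns a dict of float
# arrays in [0, 1]: {"overlay": blended image, "heatmap": raw CAM map}.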
# ===================== Config =====================
MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLASS_NAMES = ["Parasitized", "Uninfected"]
# Model weights resolution
HF_REPO_ID = os.environ.get("HF_REPO_ID", "").strip()
HF_WEIGHTS = os.environ.get("HF_WEIGHTS", "best.pt").strip() if HF_REPO_ID else ""
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt")
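# Example invocations (the repo id below is a placeholder, not a published repo):
#   HF_REPO_ID=<user>/<repo> HF_WEIGHTS=best.pt python app.py   # pull weights from the Hub
#   WEIGHTS_PATH=checkpoints/best.pt python app.py              # use a local checkpoint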
def resolve_weights() -> str:
if HF_REPO_ID:
try:
path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_WEIGHTS)
return path
except Exception as e:
print(f"[hub] failed to download {HF_REPO_ID}:{HF_WEIGHTS} â {e}")
if os.path.exists(WEIGHTS_PATH):
return WEIGHTS_PATH
candidates = ["checkpoints/best.pt", "checkpoints/last.pt"] + sorted(glob.glob("checkpoints/*.pt"))
for p in candidates:
if os.path.exists(p):
return p
raise FileNotFoundError("No checkpoint found. Upload a .pt file.")
# ===================== Model =====================
def build_model(name: str, num_classes: int):
name = name.lower()
if name == "efficientnet_b0":
m = models.efficientnet_b0(weights=None)
m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
return m
elif name == "resnet50":
m = models.resnet50(weights=None)
m.fc = nn.Linear(m.fc.in_features, num_classes)
return m
else:
raise ValueError(f"Unsupported model: {name}")
CHECKPOINT_FILE = resolve_weights()
_model = build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE)
_state = torch.load(CHECKPOINT_FILE, map_location=DEVICE)
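# Some training scripts save the raw state dict, others wrap it as
# {"state_dict": ...}; try the first layout, then fall back to the second.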
try:
_model.load_state_dict(_state)
except Exception:
_model.load_state_dict(_state["state_dict"])
_model.eval()
_pre = transforms.Compose([
transforms.Resize(int(IMAGE_SIZE*1.15)),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor(),
])
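# Sanity check (sketch): _pre maps any RGB PIL image to a float tensor of shape
# (3, IMAGE_SIZE, IMAGE_SIZE) with values in [0, 1], e.g.:
#   assert _pre(Image.new("RGB", (512, 512))).shape == (3, IMAGE_SIZE, IMAGE_SIZE)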
# ===================== Enhanced Inference =====================
def predict_enhanced(image: Image.Image, show_cam: bool):
if image is None:
return "Please upload an image", None, None, None
img = image.convert("RGB")
x = _pre(img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
logits = _model(x).cpu().numpy().squeeze()
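        # Softmax computed by hand with the max subtracted first for numerical
        # stability; equivalent to torch.softmax over the logits.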
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()
pred_idx = int(np.argmax(probs))
pred_class = CLASS_NAMES[pred_idx]
confidence = float(probs[pred_idx])
# Color-coded results
if pred_idx == 0: # Parasitized
color = "#FF4444"
status = "MALARIA DETECTED"
emoji = "đĻ "
recommendation = """
### Clinical Recommendation:
- **Immediate microscopic confirmation required**
- **Consult healthcare provider immediately**
- **Begin rapid diagnostic test (RDT)**
- **Consider antimalarial treatment if confirmed**
"""
else: # Uninfected
color = "#44FF88"
status = "NO MALARIA DETECTED"
emoji = "â
"
recommendation = """
### Clinical Recommendation:
- **Negative result - low malaria likelihood**
- **Monitor for symptoms development**
- **Consult healthcare provider if symptoms persist**
- **Consider other differential diagnoses**
"""
# Enhanced result HTML
result_html = f"""
{emoji}
{status}
AI-Powered Analysis Complete
Diagnostic Results
Confidence
{confidence*100:.2f}%
Class Probabilities
đĻ Parasitized
{probs[0]*100:.2f}%
â
Uninfected
{probs[1]*100:.2f}%
â ī¸ Medical Disclaimer
This is a research tool only and NOT a medical diagnostic device.
Results must be confirmed by certified laboratory testing and qualified healthcare professionals.
Do not make medical decisions based solely on this AI analysis.
"""
# Generate Grad-CAM
    overlay = None
    if show_cam:
        try:
            cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=DEVICE)
            overlay = Image.fromarray((cam["overlay"] * 255).astype("uint8"))
        except Exception as e:
            print(f"Grad-CAM error: {e}")
# Create probability chart
fig, ax = plt.subplots(figsize=(6, 4))
colors_bar = ['#FF4444', '#44FF88']
bars = ax.barh(CLASS_NAMES, probs*100, color=colors_bar, alpha=0.8)
ax.set_xlabel('Probability (%)', fontsize=12, fontweight='bold')
ax.set_title('Prediction Confidence', fontsize=14, fontweight='bold', pad=20)
ax.set_xlim(0, 100)
ax.grid(axis='x', alpha=0.3)
    for i, prob in enumerate(probs):
        ax.text(prob*100 + 2, i, f'{prob*100:.1f}%', va='center', fontweight='bold')
plt.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
prob_chart = Image.open(buf).convert('RGB')
plt.close()
return result_html, overlay, prob_chart, recommendation
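# Quick smoke test without the UI (sketch; "cell.png" is a placeholder path):
#   html, cam, chart, rec = predict_enhanced(Image.open("cell.png"), show_cam=False)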
# ===================== ONNX Export =====================
def export_onnx(precision: str) -> str:
    m = build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE)
    state = torch.load(CHECKPOINT_FILE, map_location=DEVICE)
    try:
        m.load_state_dict(state)
    except Exception:
        m.load_state_dict(state["state_dict"])
    m.eval()
    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    if precision == "fp16":
        m = m.half()
        dummy = dummy.half()
    dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    fname = f"model_{MODEL_NAME}_{precision}_{IMAGE_SIZE}.onnx"
    # Export straight to disk; gr.File serves a file path, not an in-memory buffer.
    torch.onnx.export(
        m, dummy, fname,
        input_names=["input"], output_names=["output"],
        dynamic_axes=dynamic_axes, opset_version=17, do_constant_folding=True,
    )
    return fname
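# Optional check of the exported file (sketch; assumes `onnxruntime` is
# installed, which this app does not otherwise require):
#   import onnxruntime as ort
#   sess = ort.InferenceSession(export_onnx("fp32"))
#   (out,) = sess.run(["output"], {"input": np.zeros((1, 3, IMAGE_SIZE, IMAGE_SIZE), np.float32)})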
# ===================== Validation =====================
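# Expected ZIP layout (read with torchvision.datasets.ImageFolder):
#   Parasitized/ img_001.png ...
#   Uninfected/  img_002.png ...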
def validate(zip_file):
    import zipfile  # tempfile is imported at module level
    if zip_file is None:
        return "Please upload a validation dataset ZIP file.", None
tmp = tempfile.mkdtemp()
with zipfile.ZipFile(zip_file.name, 'r') as zf:
zf.extractall(tmp)
ds = datasets.ImageFolder(tmp, transform=_pre)
dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
ys, ps = [], []
with torch.no_grad():
for xb, yb in dl:
preds = _model(xb.to(DEVICE)).argmax(1).cpu().numpy()
ys.extend(yb.numpy())
ps.extend(preds)
import sklearn.metrics as sk
rep = sk.classification_report(ys, ps, target_names=ds.classes, output_dict=True)
cm = sk.confusion_matrix(ys, ps)
# Enhanced confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="RdYlGn_r",
xticklabels=ds.classes, yticklabels=ds.classes,
ax=ax, cbar_kws={'label': 'Count'}, linewidths=2, linecolor='white')
ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
ax.set_title('Confusion Matrix - Validation Results', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=160)
buf.seek(0)
cm_img = Image.open(buf).convert("RGB")
plt.close()
# Format report
acc = rep['accuracy']
report_md = f"""
### Validation Results
**Overall Accuracy:** {acc*100:.2f}%
#### Per-Class Metrics:
| Class | Precision | Recall | F1-Score | Support |
|-------|-----------|--------|----------|---------|
| Parasitized | {rep['Parasitized']['precision']:.3f} | {rep['Parasitized']['recall']:.3f} | {rep['Parasitized']['f1-score']:.3f} | {rep['Parasitized']['support']:.0f} |
| Uninfected | {rep['Uninfected']['precision']:.3f} | {rep['Uninfected']['recall']:.3f} | {rep['Uninfected']['f1-score']:.3f} | {rep['Uninfected']['support']:.0f} |
**Macro Avg:** Precision={rep['macro avg']['precision']:.3f}, Recall={rep['macro avg']['recall']:.3f}, F1={rep['macro avg']['f1-score']:.3f}
"""
return report_md, cm_img
# ===================== Dashboard =====================
METRICS_DEFAULT = "checkpoints/metrics.csv"
def _plot_to_pil(df: pd.DataFrame, ycol: str, title: str, ylabel: str, color='#2196F3'):
if ycol not in df.columns:
return None
s = df[["epoch", ycol]].dropna()
if len(s) == 0:
return None
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(s["epoch"], s[ycol], marker="o", linewidth=2.5, markersize=8, color=color)
ax.fill_between(s["epoch"], s[ycol], alpha=0.3, color=color)
ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel("Epoch", fontsize=12, fontweight='bold')
ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=160, bbox_inches="tight", facecolor='white')
buf.seek(0)
img = Image.open(buf).convert("RGB")
plt.close()
return img
def load_metrics(path: str):
if not os.path.exists(path):
return "Metrics file not found. Upload your training metrics CSV.", None, None, None, None, None, None
try:
df = pd.read_csv(path)
except Exception as e:
return f"Error reading CSV: {e}", None, None, None, None, None, None
# Normalize column names for compatibility
col_map = {
'activation_rate': 'act_rate',
'energy_savings': 'save_rate'
}
for old_col, new_col in col_map.items():
if old_col in df.columns and new_col not in df.columns:
df[new_col] = df[old_col]
    # Summary statistics; render the last epochs as a markdown table, falling
    # back to plain text if `tabulate` (required by to_markdown) is missing
    try:
        last_5_table = df.tail(5).to_markdown(index=False)
    except ImportError:
        last_5_table = "```\n" + df.tail(5).to_string(index=False) + "\n```"
    energy_line = ""
    if "save_rate" in df.columns:
        energy_line = f"**Average Energy Savings:** {df['save_rate'].mean()*100:.1f}%"
    summary = f"""
### Training Summary
**Total Epochs:** {len(df)}
**Best Validation Accuracy:** {df['val_acc'].max()*100:.2f}% (Epoch {df['val_acc'].idxmax() + 1})
**Final Training Loss:** {df['train_loss'].iloc[-1]:.4f}
{energy_line}
#### Last 5 Epochs:
{last_5_table}
"""
    fig_loss = _plot_to_pil(df, "train_loss", "Training Loss Over Time", "Loss", color='#FF5722')
    fig_acc = _plot_to_pil(df, "val_acc", "Validation Accuracy Over Time", "Accuracy", color='#4CAF50')
    fig_act = _plot_to_pil(df, "act_rate", "Activation Rate (AST)", "Activation Rate", color='#2196F3')
    fig_save = _plot_to_pil(df, "save_rate", "Energy Savings (AST)", "Savings Fraction", color='#9C27B0')
    fig_thr = _plot_to_pil(df, "threshold", "Activation Threshold (AST)", "Threshold", color='#FF9800')
    # gr.DownloadButton serves a file path, so write the CSV to a temp file
    csv_path = os.path.join(tempfile.mkdtemp(), "metrics.csv")
    df.to_csv(csv_path, index=False)
    return summary, fig_loss, fig_acc, fig_act, fig_save, fig_thr, csv_path
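# Expected metrics.csv columns (extras are ignored): epoch, train_loss, val_acc,
# threshold, and act_rate / save_rate (or their long-form aliases
# activation_rate / energy_savings, normalized above).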
def compare_runs(files):
    if not files:
return "Upload 2 or more metrics.csv files to compare training runs.", None, None, None, None, None
runs = []
for f in files:
try:
df = pd.read_csv(f.name)
# Normalize column names
col_map = {'activation_rate': 'act_rate', 'energy_savings': 'save_rate'}
for old_col, new_col in col_map.items():
if old_col in df.columns and new_col not in df.columns:
df[new_col] = df[old_col]
runs.append((os.path.basename(f.name), df))
except Exception as e:
return f"Error reading {f.name}: {e}", None, None, None, None, None
def _overlay_plot(runs, ycol, title, ylabel):
fig, ax = plt.subplots(figsize=(10, 6))
colors = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0', '#FF9800', '#00BCD4']
found = False
for i, (name, df) in enumerate(runs):
if ycol in df.columns:
s = df[["epoch", ycol]].dropna()
if len(s):
color = colors[i % len(colors)]
ax.plot(s["epoch"], s[ycol], marker="o", linewidth=2,
markersize=6, label=name, color=color, alpha=0.8)
found = True
if not found:
plt.close(fig)
return None
ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel("Epoch", fontsize=12, fontweight='bold')
ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(fontsize=10, loc='best')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=160, bbox_inches="tight", facecolor='white')
buf.seek(0)
img = Image.open(buf).convert("RGB")
plt.close()
return img
msg = f"Successfully compared {len(runs)} training runs."
p_loss = _overlay_plot(runs, "train_loss", "Training Loss Comparison", "Loss")
p_acc = _overlay_plot(runs, "val_acc", "Validation Accuracy Comparison", "Accuracy")
p_act = _overlay_plot(runs, "act_rate", "Activation Rate Comparison", "Activation Rate")
p_save = _overlay_plot(runs, "save_rate", "Energy Savings Comparison", "Savings Fraction")
p_thr = _overlay_plot(runs, "threshold", "Threshold Comparison", "Threshold")
return msg, p_loss, p_acc, p_act, p_save, p_thr
# ===================== Custom CSS =====================
custom_css = """
#header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 30px;
border-radius: 15px;
margin-bottom: 30px;
text-align: center;
color: white;
}
#header h1 {
margin: 0;
font-size: 42px;
font-weight: 800;
}
#header p {
margin: 10px 0 0 0;
font-size: 18px;
opacity: 0.95;
}
.badge {
display: inline-block;
padding: 8px 16px;
margin: 5px;
background: rgba(255,255,255,0.2);
border-radius: 20px;
font-weight: 600;
}
"""
# ===================== Gradio UI =====================
with gr.Blocks(title="Malaria Detection AI - Advanced Diagnostics", css=custom_css, theme=gr.themes.Soft()) as demo:
# Header
gr.HTML("""
""")
gr.Markdown(f"""
### System Information
**Model:** `{MODEL_NAME}` | **Weights:** `{os.path.basename(CHECKPOINT_FILE)}` | **Device:** `{DEVICE}` | **Classes:** {NUM_CLASSES}
""")
# -------- Main Inference Tab --------
with gr.Tab("đ Diagnosis"):
gr.Markdown("""
### Upload Blood Smear Image for Analysis
Upload a microscopy image of a blood cell to detect malaria parasites using AI.
""")
with gr.Row():
with gr.Column(scale=1):
img_in = gr.Image(type="pil", label="Upload Blood Cell Image", height=400)
show_cam = gr.Checkbox(value=True, label="Show Grad-CAM Visualization (Explainable AI)")
btn_pred = gr.Button("đŦ Analyze for Malaria", variant="primary", size="lg")
with gr.Column(scale=1):
result_out = gr.HTML(label="Diagnostic Results")
with gr.Row():
with gr.Column():
cam_out = gr.Image(type="pil", label="Grad-CAM Heat Map (Where AI Looks)")
with gr.Column():
chart_out = gr.Image(type="pil", label="Confidence Distribution")
recommendation_out = gr.Markdown(label="Clinical Recommendations")
btn_pred.click(
fn=predict_enhanced,
inputs=[img_in, show_cam],
outputs=[result_out, cam_out, chart_out, recommendation_out]
)
# -------- Validation Tab --------
with gr.Tab("â
Model Validation"):
gr.Markdown("""
### Validate Model Performance
Upload a ZIP file containing a validation dataset (with Parasitized/ and Uninfected/ folders).
""")
val_zip = gr.File(label="Upload Validation Dataset (.zip)", file_types=[".zip"])
btn_eval = gr.Button("đ Run Validation", variant="primary")
with gr.Row():
with gr.Column():
rep_out = gr.Markdown(label="Classification Report")
with gr.Column():
cm_img = gr.Image(type="pil", label="Confusion Matrix")
btn_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])
# -------- Dashboard Tab --------
with gr.Tab("đ Training Dashboard"):
gr.Markdown("""
### Visualize Training Metrics
View training progress, validation accuracy, and energy savings from Adaptive Sparse Training (AST).
**Metrics are automatically loaded from checkpoints/metrics.csv. Upload a different file if needed.**
""")
# Load initial metrics from checkpoints
initial_summary, initial_loss, initial_acc, initial_act, initial_save, initial_thr, initial_csv = load_metrics(METRICS_DEFAULT)
# Optional: File upload to override
with gr.Row():
single_metrics_file = gr.File(label="Upload different metrics.csv (optional)", file_types=[".csv"])
btn_upload = gr.Button("đ Load Metrics", variant="secondary")
summary_md = gr.Markdown(value=initial_summary)
with gr.Row():
plot_loss = gr.Image(label="Training Loss", value=initial_loss)
plot_acc = gr.Image(label="Validation Accuracy", value=initial_acc)
with gr.Row():
plot_act = gr.Image(label="Activation Rate (AST)", value=initial_act)
plot_save = gr.Image(label="Energy Savings (AST)", value=initial_save)
plot_thr = gr.Image(label="Activation Threshold (AST)", value=initial_thr)
dl_btn = gr.DownloadButton(label="âŦī¸ Download Metrics CSV")
# Update download button with initial CSV data
demo.load(
fn=lambda: initial_csv,
inputs=[],
outputs=[dl_btn]
)
# Connect upload button for custom metrics
btn_upload.click(
fn=lambda f: load_metrics(f.name if f else ""),
inputs=[single_metrics_file],
outputs=[summary_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn]
)
gr.Markdown("---")
gr.Markdown("### Compare Multiple Training Runs")
mult = gr.Files(label="Upload Multiple metrics.csv Files", file_types=[".csv"])
cmp_msg = gr.Markdown()
with gr.Row():
p_loss = gr.Image(label="Loss Comparison")
p_acc = gr.Image(label="Accuracy Comparison")
with gr.Row():
p_act = gr.Image(label="Activation Comparison")
p_save = gr.Image(label="Savings Comparison")
p_thr = gr.Image(label="Threshold Comparison")
mult.upload(
fn=compare_runs,
inputs=[mult],
outputs=[cmp_msg, p_loss, p_acc, p_act, p_save, p_thr]
)
# -------- Export Tab --------
with gr.Tab("đĻ Model Export"):
gr.Markdown("""
### Export Model to ONNX Format
Convert the PyTorch model to ONNX format for production deployment and cross-platform compatibility.
""")
with gr.Row():
with gr.Column():
onnx_precision = gr.Radio(
choices=["fp32", "fp16"],
value="fp32",
label="Precision",
info="FP16 for faster inference, FP32 for maximum accuracy"
)
btn_onnx = gr.Button("đ Export to ONNX", variant="primary")
with gr.Column():
onnx_file = gr.File(label="Download ONNX Model", interactive=False)
        btn_onnx.click(fn=export_onnx, inputs=[onnx_precision], outputs=[onnx_file])
# -------- About Tab --------
with gr.Tab("âšī¸ About"):
gr.Markdown("""
## About This System
### Technology Stack
- **Deep Learning Framework:** PyTorch
- **Model Architecture:** EfficientNet-B0
- **Training Method:** Adaptive Sparse Training (AST) with Sundew Algorithm
- **Explainable AI:** Grad-CAM (Gradient-weighted Class Activation Mapping)
- **Dataset:** NIH Malaria Cell Images (27,558 samples)
### Performance Metrics
- **Validation Accuracy:** 93.94% (final epoch), 94.63% (best epoch)
- **Energy Savings:** 88% reduction in training cost vs. traditional methods
- **Inference Speed:** <1 second per image
- **Model Size:** ~16MB
- **Training:** 30 epochs on NIH Malaria Dataset
### Key Features
1. **Real-time Diagnosis:** Upload blood smear images for instant analysis
2. **Explainable AI:** Grad-CAM shows exactly where the model detects parasites
3. **Energy Efficient:** Trained using Adaptive Sparse Training with Sundew algorithm for 88% energy savings
4. **Clinical Recommendations:** Actionable advice based on predictions
5. **Model Validation:** Built-in tools for performance evaluation
6. **ONNX Export:** Deploy anywhere with standard model format
### Use Cases
- **Research:** Academic studies on malaria detection
- **Education:** Teaching AI applications in healthcare
- **Triage:** Rapid pre-screening in resource-limited settings
- **Model Comparison:** Benchmark against other approaches
### Important Disclaimers
⚠️ **This is a research prototype and NOT a medical device.**
- Results must be confirmed by certified laboratory testing
- Do not use for clinical diagnosis without professional validation
- Always consult qualified healthcare providers
- This tool is for research and educational purposes only
### Developer
**Oluwafemi Idiakhoa**
### Citation
If you use this system in your research, please cite:
```
@software{malaria_ast_detection,
author = {Idiakhoa, Oluwafemi},
title = {Malaria Detection using Adaptive Sparse Training},
year = {2025},
url = {https://github.com/oluwafemidiakhoa/Malaria}
}
```
### Resources
- [GitHub Repository](https://github.com/oluwafemidiakhoa/Malaria)
- [AST Library (PyPI)](https://pypi.org/project/adaptive-sparse-training/)
- [NIH Malaria Dataset](https://lhncbc.nlm.nih.gov/LHC-research/LHC-projects/image-processing/malaria-datasheet.html)
---
Built with EfficientNet-B0 + Adaptive Sparse Training | Powered by PyTorch & Gradio
""")
# Footer
gr.Markdown("""
---
Malaria Detection AI | Advanced Deep Learning for Global Health
Developer: Oluwafemi Idiakhoa | GitHub
This is a research tool. Not for clinical use.
""")
if __name__ == "__main__":
demo.launch()