|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import os |
|
|
import io |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
from sklearn.preprocessing import OneHotEncoder |
|
|
from sklearn.model_selection import train_test_split |
|
|
import tqdm |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import gradio as gr |
|
|
from zennit.composites import EpsilonPlusFlat |
|
|
from zennit.canonizers import SequentialMergeBatchNorm |
|
|
from crp.attribution import CondAttribution |
|
|
from crp.concepts import ChannelConcept |
|
|
from crp.helper import get_layer_names |
|
|
from crp.image import imgify |
|
|
|
|
|
# Training hyper-parameters; they appear unused by the inference code in this script.
epoch = 0
lr = 1e-4

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# When True, the fine-tuned weights in edd_inception.pth are loaded at start-up.
load_model = True
|
|
|
|
|
# Model output index -> human-readable diagnosis shown in the UI. Entries are listed in
# alphabetical order, presumably matching the class-folder order used at training time
# (TODO confirm); the classifier head needs exactly one output per entry.
class_labels = {
    0: "Central Serous Chorioretinopathy",
    1: "Diabetic Retinopathy",
    2: "Disc Edema",
    3: "Glaucoma",
    4: "Healthy",
    5: "Macular Scar",
    6: "Myopia",
    7: "Pterygium",  # was misspelled "Pteryguim" in the displayed label
    8: "Retinal Detachment",
    9: "Retinitis Pigmentosa",
}
|
|
|
|
|
# Preprocessing shared by classification and every explainer:
#   1. resize to a fixed 324x324 (presumably the training resolution — confirm),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. map each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(324, 324)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
# Inception-v3 backbone. aux_logits / transform_input / init_weights are pinned to the values
# torchvision itself uses when it builds the pretrained model, so the architecture (and hence
# the state-dict keys) is identical whether or not the ImageNet weights are fetched. They are
# only fetched when load_model is False: otherwise the fine-tuned checkpoint loaded below
# overwrites every parameter, and downloading them would just require network access for nothing.
model = models.inception_v3(
    pretrained=not load_model,
    aux_logits=True,
    transform_input=True,
    init_weights=False,
)

# Replace the backbone's classification head with one output per diagnosis in class_labels.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_labels))

if load_model:
    # The checkpoint is looked up in the current working directory; map_location lets a
    # checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(
        torch.load(os.path.join(os.getcwd(), "edd_inception.pth"), map_location=device)
    )

model.to(device)
# Inference-only app: switch BatchNorm/Dropout to evaluation behaviour once, up front.
model.eval()

# Names of every Conv2d / Linear layer; offered as choices in the CRP layer dropdown.
layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])
|
|
|
|
|
class GradCAM:
    """Grad-CAM saliency maps for one layer of ``model``.

    A forward hook caches ``target_layer``'s output and a full backward hook caches the
    gradient flowing into that output. ``generate`` weights every channel of the cached
    activations by its spatially averaged gradient, sums over channels, applies ReLU,
    resizes to the input's spatial size and min-max scales the result to [0, 1].

    The hooks stay registered for the object's lifetime; they are removed by ``__del__``
    or explicitly with ``_remove_hooks``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # set by the backward hook: d(score) / d(target_layer output)
        self.activations = None  # set by the forward hook: target_layer output
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the hooks that populate ``self.activations`` and ``self.gradients``."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the module's (first) output.
            self.gradients = grad_output[0].detach()

        self.hooks = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def _remove_hooks(self):
        """Remove every registered hook; calling it again is a no-op."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def __del__(self):
        # getattr guard: if __init__ raised before ``hooks`` existed there is nothing to remove.
        if getattr(self, "hooks", None):
            self._remove_hooks()

    @staticmethod
    def _compute_cam(gradients, activations):
        """Return ReLU(sum_k w_k * A_k) for sample 0 as a float32 (H, W) array.

        ``w_k`` is the mean gradient of channel ``k``. Both arguments are expected to be
        (B, C, H, W) tensors as captured by the hooks (the code takes sample 0 and averages
        over the last two axes).
        """
        grads = gradients[0].detach().cpu().numpy()
        acts = activations[0].detach().cpu().numpy()
        weights = grads.mean(axis=(1, 2))  # one weight per channel
        cam = (weights[:, None, None] * acts).sum(axis=0)
        # float32 keeps the map in a dtype cv2.resize accepts even for half-precision inputs.
        return np.maximum(cam.astype(np.float32, copy=False), 0)

    @staticmethod
    def _normalize(cam):
        """Min-max scale ``cam`` to [0, 1]; a constant map becomes all zeros instead of NaN."""
        cam = cam - np.min(cam)
        peak = np.max(cam)
        return cam / peak if peak > 0 else cam

    def generate(self, input_tensor, class_idx=None):
        """Return a Grad-CAM heatmap in [0, 1] for the first sample of ``input_tensor``.

        Args:
            input_tensor: model input; ``shape[2]`` / ``shape[3]`` are read as height / width,
                i.e. a (B, C, H, W) batch is expected. Only sample 0 is explained.
            class_idx: class whose score is explained; defaults to the top-1 prediction.

        Returns:
            float32 array of shape (H, W) scaled to [0, 1] (all zeros for a constant map).

        Raises:
            RuntimeError: if the target-layer hooks did not fire during forward/backward.
        """
        self.model.zero_grad()
        self.model.eval()
        # Reset so stale data from an earlier call can never be mistaken for this one.
        self.gradients = None
        self.activations = None

        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output[0].argmax().item()
        output[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM hooks did not fire; is target_layer part of the model's forward pass?"
            )

        cam = self._compute_cam(self.gradients, self.activations)
        # cv2.resize takes dsize as (width, height).
        cam = cv2.resize(cam, (input_tensor.shape[3], input_tensor.shape[2]))
        return self._normalize(cam)
|
|
|
|
|
class GradCAMPlusPlus:
    """Grad-CAM++ saliency maps for one layer of ``model``.

    Hooks on ``target_layer`` capture its output activations ``A`` and output gradients ``g``.
    Per position, alpha = relu(g)^2 / (2 relu(g)^2 + sum_ab(A) * g^3 + eps); channel weights
    are sum_ij(alpha * relu(g)); the map is ReLU(sum_k weight_k * A_k), resized to the input
    and min-max scaled to [0, 1].

    The hooks stay registered for the object's lifetime; they are removed by ``__del__``
    or explicitly with ``_remove_hooks``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # set by the backward hook: d(score) / d(target_layer output)
        self.activations = None  # set by the forward hook: target_layer output
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the hooks that populate ``self.activations`` and ``self.gradients``."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the module's (first) output.
            self.gradients = grad_output[0].detach()

        self.hooks.append(self.target_layer.register_forward_hook(forward_hook))
        self.hooks.append(self.target_layer.register_full_backward_hook(backward_hook))

    def _remove_hooks(self):
        """Remove every registered hook; calling it again is a no-op."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def __del__(self):
        # getattr guard: if __init__ raised before ``hooks`` existed there is nothing to remove.
        if getattr(self, "hooks", None):
            self._remove_hooks()

    @staticmethod
    def _compute_cam(gradients, activations, eps=1e-8):
        """Return the ReLU'd Grad-CAM++ map of sample 0 as a NumPy (H, W) array.

        Both arguments are expected to be (B, C, H, W) tensors as captured by the hooks.
        """
        positive_gradients = F.relu(gradients)
        alpha_num = positive_gradients.pow(2)
        alpha_den = (
            2 * alpha_num
            + activations.sum(dim=[2, 3], keepdim=True) * gradients.pow(3)
            + eps  # avoids 0/0 where the gradient is zero
        )
        alpha = alpha_num / alpha_den
        weights = (alpha * positive_gradients).sum(dim=[2, 3], keepdim=True)
        cam = (weights * activations).sum(dim=1)[0].detach().cpu().numpy()
        return np.maximum(cam, 0)

    @staticmethod
    def _normalize(cam):
        """Min-max scale ``cam`` to [0, 1]; a constant map becomes all zeros instead of NaN."""
        cam = cam - np.min(cam)
        peak = np.max(cam)
        return cam / peak if peak > 0 else cam

    def generate(self, input_tensor, class_idx=None, eps=1e-8):
        """Return a Grad-CAM++ heatmap in [0, 1] for the first sample of ``input_tensor``.

        Args:
            input_tensor: model input; ``shape[2]`` / ``shape[3]`` are read as height / width,
                i.e. a (B, C, H, W) batch is expected. Only sample 0 is explained.
            class_idx: class whose score is explained; defaults to the top-1 prediction.
            eps: stabiliser added to the alpha denominator.

        Returns:
            Array of shape (H, W) scaled to [0, 1] (all zeros for a constant map).

        Raises:
            RuntimeError: if the target-layer hooks did not fire during forward/backward.
        """
        self.model.zero_grad()
        # Explain the deterministic inference model (BatchNorm/Dropout in eval mode).
        self.model.eval()
        self.gradients = None
        self.activations = None

        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output[0].argmax().item()

        # Backpropagate only the chosen class score of sample 0.
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAMPlusPlus hooks did not fire; is target_layer part of the model's forward pass?"
            )

        cam = self._compute_cam(self.gradients, self.activations, eps)
        # cv2.resize takes dsize as (width, height).
        cam = cv2.resize(cam, (input_tensor.shape[3], input_tensor.shape[2]))
        return self._normalize(cam)
|
|
|
|
|
class SmoothGradCAMPlusPlus:
    """Smooth Grad-CAM++: Grad-CAM++ maps averaged over noisy copies of the input.

    Each of ``num_samples`` passes adds Gaussian noise with standard deviation
    ``noise_level`` to the input, computes a min-max normalised Grad-CAM++ map for the same
    class from hooks on ``target_layer``, and the maps are averaged.

    The hooks stay registered for the object's lifetime; they are removed by ``__del__``
    or explicitly with ``_remove_hooks``.
    """

    def __init__(self, model, target_layer, num_samples=25, noise_level=0.1):
        # Validate before registering hooks: generate() divides by num_samples.
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.model = model
        self.target_layer = target_layer
        self.num_samples = num_samples
        self.noise_level = noise_level  # std of the additive Gaussian input noise
        self.hooks = []
        self.gradients = None  # set by the backward hook: d(score) / d(target_layer output)
        self.activations = None  # set by the forward hook: target_layer output
        self._register_hooks()
        self.model.eval()

    def _register_hooks(self):
        """Attach the hooks that populate ``self.activations`` and ``self.gradients``."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the module's (first) output.
            self.gradients = grad_output[0].detach()

        self.hooks.append(self.target_layer.register_forward_hook(forward_hook))
        self.hooks.append(self.target_layer.register_full_backward_hook(backward_hook))

    def _remove_hooks(self):
        """Remove every registered hook; calling it again is a no-op."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def __del__(self):
        # getattr guard: __init__ may raise (e.g. bad num_samples) before ``hooks`` exists.
        if getattr(self, "hooks", None):
            self._remove_hooks()

    @staticmethod
    def _cam_from(gradients, activations, eps=1e-8):
        """Return the ReLU'd Grad-CAM++ map of sample 0 as a NumPy (H, W) array.

        alpha = relu(g)^2 / (2 relu(g)^2 + sum_ij(A * g^3) + eps),
        weight_k = sum_ij(alpha * relu(g)), cam = ReLU(sum_k weight_k * A_k).
        Both arguments are expected to be (B, C, H, W) tensors as captured by the hooks.
        """
        positive_gradients = F.relu(gradients)
        alpha_numerator = positive_gradients.pow(2)
        # NOTE(review): this denominator sums A * g^3 over spatial positions, whereas
        # GradCAMPlusPlus uses sum(A) * g^3 evaluated per position, so the two explainers
        # weight channels differently — confirm which variant is intended.
        alpha_denominator = (
            2 * alpha_numerator
            + (activations * gradients.pow(3)).sum(dim=[2, 3], keepdim=True)
            + eps
        )
        alpha = alpha_numerator / alpha_denominator
        weights = (alpha * positive_gradients).sum(dim=[2, 3], keepdim=True)
        cam = (weights * activations).sum(dim=1)[0].detach().cpu().numpy()
        return np.maximum(cam, 0)

    def _compute_gradcampp(self, input_tensor, class_idx, eps=1e-8):
        """Return one Grad-CAM++ map for ``class_idx``, resized to the input and scaled to [0, 1]."""
        self.gradients = None
        self.activations = None
        output = self.model(input_tensor)
        self.model.zero_grad()

        # Backpropagate only the chosen class score of sample 0.
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "SmoothGradCAMPlusPlus hooks did not fire; is target_layer part of the model's forward pass?"
            )

        cam = self._cam_from(self.gradients, self.activations, eps)
        # cv2.resize takes dsize as (width, height).
        cam = cv2.resize(cam, (input_tensor.shape[3], input_tensor.shape[2]))
        return (cam - cam.min()) / (cam.max() - cam.min() + eps)

    def generate(self, input_tensor, class_idx=None):
        """Return the average of ``num_samples`` noisy Grad-CAM++ maps, shape (H, W), in [0, 1].

        Args:
            input_tensor: model input; ``shape[2]`` / ``shape[3]`` are read as height / width,
                i.e. a (B, C, H, W) batch is expected. Only sample 0 is explained.
            class_idx: class to explain; defaults to the top-1 prediction on the clean input,
                so every noisy sample is explained for the same class.
        """
        # Re-assert eval mode in case the model was switched to training after __init__.
        self.model.eval()
        if class_idx is None:
            with torch.no_grad():
                class_idx = self.model(input_tensor).argmax(dim=1).item()

        smooth_heatmap = np.zeros((input_tensor.shape[2], input_tensor.shape[3]))
        for _ in range(self.num_samples):
            noisy_input = input_tensor + torch.randn_like(input_tensor) * self.noise_level
            smooth_heatmap += self._compute_gradcampp(noisy_input, class_idx)
        return smooth_heatmap / self.num_samples
|
|
|
|
|
def overlay_heatmap(img, heatmap, alpha=0.5):
    """Blend a heatmap, rendered with OpenCV's JET colormap, over an image.

    Args:
        img: PIL image or array. PIL inputs of any mode are converted to RGB; a 2-D
            (grayscale) array is replicated to three channels and a 4-channel array loses
            its alpha channel, so the image always matches the 3-channel heatmap.
            Arrays are assumed to already be uint8 (TODO confirm for non-PIL callers).
        heatmap: 2-D map expected in [0, 1]. It is resized to the image size; NaNs become 0
            and values are clipped to [0, 1] before colour mapping.
        alpha: weight of the image; the coloured heatmap gets ``1 - alpha``.

    Returns:
        RGB array with the image's height and width.
    """
    if hasattr(img, "convert"):  # PIL image: normalise L / RGBA / P / ... to RGB
        img = img.convert("RGB")
    img = np.array(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 4:
        img = np.ascontiguousarray(img[:, :, :3])

    heatmap = np.nan_to_num(heatmap, nan=0.0)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # Clip so out-of-range values cannot wrap around in the uint8 cast below.
    heatmap = np.clip(heatmap, 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; the image is RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    return cv2.addWeighted(img, alpha, heatmap, 1 - alpha, 0)
|
|
|
|
|
# All three CAM explainers hook the same layer, model.Mixed_7c. They are built once at import
# time, and each keeps its hooks registered for as long as it lives.
target_layer = model.Mixed_7c
gradcam, gradcampp, smooth_gradcampp = (
    GradCAM(model, target_layer),
    GradCAMPlusPlus(model, target_layer),
    SmoothGradCAMPlusPlus(model, target_layer),
)
|
|
|
|
|
def classify_image(img_tensor):
    """Return the ``class_labels`` name of the model's top-1 prediction for ``img_tensor``.

    Runs the global ``model`` in eval mode without building an autograd graph. The
    ``.item()`` call means a single-sample batch is expected.
    """
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
|
|
|
|
|
def generate_gradcams(img, img_tensor, xai_technique):
    """Compute a CAM-family heatmap for ``img_tensor`` and overlay it on ``img``.

    Args:
        img: original image used as the overlay background.
        img_tensor: preprocessed input batch; its ``requires_grad`` flag is switched on.
        xai_technique: "Grad-CAM" or "Grad-CAM++"; any other value selects Smooth Grad-CAM++.

    Returns:
        The array produced by ``overlay_heatmap``.
    """
    img_tensor.requires_grad = True

    if xai_technique == "Grad-CAM":
        explainer = gradcam
    elif xai_technique == "Grad-CAM++":
        explainer = gradcampp
    else:
        explainer = smooth_gradcampp

    heatmap = explainer.generate(img_tensor)
    return overlay_heatmap(img, heatmap)
|
|
|
|
|
def process_image(img):
    """Gradio handler for "Get Diagnosis": classify ``img`` and return text for the textbox.

    Args:
        img: PIL image from the upload component, or None when nothing is uploaded.

    Returns:
        The predicted label, "Please upload an image" when ``img`` is None, or
        "Error processing image: ..." when preprocessing or inference raises.
    """
    if img is None:
        return "Please upload an image"

    try:
        # transform normalises exactly three channels, so convert RGBA / grayscale /
        # palette uploads to RGB first.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)
        return classify_image(img_tensor)
    except Exception as e:
        # UI boundary: show the failure in the textbox instead of failing the request.
        return f"Error processing image: {str(e)}"
|
|
|
|
|
def crp(img_tensor, selected_layer, top_k=3):
    """Concept Relevance Propagation (CRP) explanation for the model's top-1 class.

    First the predicted class is attributed back to ``selected_layer`` and that layer's
    channels ("concepts", via ``ChannelConcept``) are ranked by normalised relevance. Then
    one conditional heatmap is computed per top-ranked concept.

    Args:
        img_tensor: preprocessed input batch (the ranking uses sample 0).
        selected_layer: name of a layer in ``layer_names``, e.g. "Mixed_7c.branch_pool.conv".
        top_k: number of concepts to visualise; capped at the layer's channel count.

    Returns:
        ``(image, concept_ids)``: the ``imgify`` grid with one heatmap per concept side by
        side, and the chosen channel indices as a list of ints.
    """
    # NOTE: this function's name shadows the imported ``crp`` package at module level. The
    # package members were already bound by the ``from crp... import`` lines, so nothing
    # breaks here, but a later ``crp.<submodule>`` access in this file would hit this function.
    model.eval()
    with torch.no_grad():
        predicted_class = model(img_tensor).argmax(dim=1).item()

    composite = EpsilonPlusFlat([SequentialMergeBatchNorm()])
    attribution = CondAttribution(model, no_param_grad=True)
    cc = ChannelConcept()

    # Only selected_layer's relevances are read, so record just that layer.
    attr = attribution(
        img_tensor, [{"y": [predicted_class]}], composite, record_layer=[selected_layer]
    )
    rel_c = cc.attribute(attr.relevances[selected_layer], abs_norm=True)
    _, concept_ids = torch.topk(rel_c[0], min(top_k, rel_c[0].numel()))

    # One condition per concept: restrict the relevance flow to that channel of the layer.
    conditions = [{selected_layer: [cid], "y": [predicted_class]} for cid in concept_ids]
    heatmap, _, _, _ = attribution(img_tensor, conditions, composite)

    return imgify(heatmap, symmetric=True, grid=(1, len(concept_ids))), concept_ids.tolist()
|
|
|
|
|
def process_xai(img, explainability_method, selected_layer):
    """Gradio handler for the "Analysis" button.

    Args:
        img: uploaded PIL image, or None when nothing has been uploaded.
        explainability_method: one of the radio choices ("Grad-CAM", "Grad-CAM++",
            "Smooth Grad-CAM++", "CRP"), or None when nothing is selected.
        selected_layer: layer name used by CRP; ignored by the Grad-CAM family.

    Returns:
        ``(visualisation, concept_ids_text)``: for CRP the ``imgify`` grid and the string
        form of the concept-id list; otherwise the heatmap overlay and None.

    Raises:
        gr.Error: with a user-facing message when the image, the method, or (for CRP) the
            layer is missing.
    """
    if img is None:
        raise gr.Error("Please upload an image")
    if not explainability_method:
        # Without this check an unselected method would fall through to Smooth Grad-CAM++.
        raise gr.Error("Please select an XAI method")

    # transform normalises exactly three channels, so convert RGBA / grayscale / palette
    # uploads; the converted image is also what the heatmap is overlaid on.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    if explainability_method == "CRP":
        if not selected_layer:
            raise gr.Error("Please select a layer for CRP")
        # Independent leaf copy with requires_grad=True for the attribution pass.
        crp_tensor = img_tensor.clone().detach().requires_grad_()
        heatmap, concept_ids = crp(crp_tensor, selected_layer)
        return heatmap, str(concept_ids)

    gradcam_tensor = img_tensor.clone().detach().requires_grad_()
    heatmap = generate_gradcams(img, gradcam_tensor, explainability_method)
    return heatmap, None
|
|
|
|
|
|
|
|
# Markdown introduction rendered at the top of the Gradio page.
intro_text = """
## 👁 Explainable AI (XAI) in Eye Disease Detection

👨‍⚕️ **Meet Amit:** A 45-year-old man experiencing blurry vision. His doctor recommends an AI-based retinal scan.
<br>
❌ **The Problem:** Traditional AI models provide a diagnosis but lack transparency. How do we know **why** the model predicted a disease?
<br>
✅ **Solution - XAI:** Explainability techniques like Grad-CAM highlight **which parts of the image influenced the decision**. This helps doctors **validate the AI’s diagnosis**.
<br>
📷 **Try it yourself!** Upload a retinal image, get a prediction, and compare it with XAI-based explanations.
"""
|
|
|
|
|
|
|
|
# Two-column Gradio page: plain diagnosis on the left, diagnosis + XAI visualisation on the right.
with gr.Blocks() as demo:
    gr.Markdown(intro_text)

    with gr.Row():
        with gr.Column():
            gr.Markdown("**🔍 AI Diagnosis Without XAI**")
            input_img = gr.Image(type="pil", label="Upload Retinal Image")
            diagnosis = gr.Textbox(label="Prediction Only")
            diagnose_btn = gr.Button("Get Diagnosis")

        with gr.Column():
            gr.Markdown("**🔍 AI Diagnosis With XAI Explanation**")
            xai_method = gr.Radio(
                ["Grad-CAM", "Grad-CAM++", "Smooth Grad-CAM++", "CRP"],
                label="Select XAI Method",
            )
            # CRP-only controls start hidden; toggle_layer_visibility shows them for "CRP".
            layer_selector = gr.Dropdown(
                choices=layer_names,
                label="Select Layer for CRP",
                visible=False,
                value="Mixed_7c.branch_pool.conv",
            )
            concept_ids_box = gr.Textbox(
                label="Concept IDs (for CRP)",
                interactive=False,
                visible=False,
            )
            xai_img = gr.Image(label="XAI Visualization")
            xai_btn = gr.Button("Analysis")

    def toggle_layer_visibility(method):
        """Show the layer dropdown and concept-id box only when CRP is selected."""
        is_crp = method == "CRP"
        return [
            gr.update(visible=is_crp),
            gr.update(visible=is_crp),
        ]

    xai_method.change(
        toggle_layer_visibility,
        inputs=[xai_method],
        outputs=[layer_selector, concept_ids_box],
    )

    diagnose_btn.click(
        process_image,
        inputs=[input_img],
        outputs=[diagnosis],
    )

    xai_btn.click(
        process_xai,
        inputs=[input_img, xai_method, layer_selector],
        outputs=[xai_img, concept_ids_box],
    )


if __name__ == "__main__":
    demo.launch()