import base64
import io

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


def _normalize_cam(cam):
    """Min-max scale *cam* into [0, 1].

    The small epsilon in the denominator avoids division by zero. A constant
    map (e.g. all zeros after ReLU) therefore comes back as all zeros.
    """
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-6)
    return cam


def _weighted_cam(activation, grad):
    """Combine a captured activation and its gradient into a 2-D map (h, w).

    Each channel of the activation is weighted by the spatial mean of its
    gradient, and the weighted channels are summed. If the channel counts of
    the activation and gradient disagree, fall back to the plain channel mean
    of the activation.
    """
    A = activation.detach()
    if A.dim() == 4:
        A = A[0]  # drop batch dim -> (C, h, w)
    elif A.dim() != 3:
        # Unusual layout: best-effort reduction over the leading dim.
        A = A.mean(dim=0)

    G = grad.detach()
    if G.dim() == 4:
        G = G[0]  # drop batch dim -> (C, h, w)

    if G.shape[0] == A.shape[0]:
        weights = G.mean(dim=(1, 2))  # (C,)
        return (weights[:, None, None] * A).sum(0)  # (h, w)
    return A.mean(dim=0)  # safe fallback


def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
    """Compute a Grad-CAM heatmap for the model's top-1 prediction on *img*.

    Args:
        model: A classifier returning logits indexed as ``logits[0, class]``.
            It is put into eval mode as a side effect. Gradients from the
            backward pass are left on its parameters.
        img: Input PIL image. It is resized to ``int(img_size * 1.15)``,
            center-cropped to ``img_size`` and converted with ``ToTensor``.
        img_size: Side length of the square crop and of the returned maps.
        target_layer: Module whose output activations and gradients are used.
            Defaults to ``model.features[-1][0]``.
        device: Device the input tensor is moved to. The model is not moved.

    Returns:
        dict with keys:
          ``"pred"`` (int): argmax class index.
          ``"probs"`` (np.ndarray): softmax over ``logits[0]``.
          ``"overlay"`` (np.ndarray, H x W x 3): 0.6 * image + 0.4 * jet
          heatmap, clipped to [0, 1].
          ``"input_image"`` (np.ndarray, H x W x C): the cropped input,
          min-max scaled.
          ``"cam"`` (np.ndarray, H x W): the normalized CAM in [0, 1].

    Raises:
        RuntimeError: if ``target_layer`` was never executed during the
            forward pass, so no activation or gradient was captured.
    """
    model.eval()
    # NOTE(review): no mean/std normalization is applied before the model;
    # confirm this matches the preprocessing the model was trained with.
    tfms = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    x = tfms(img).unsqueeze(0).to(device)
    x.requires_grad_(True)

    if target_layer is None:
        # Assumes a torchvision-style EfficientNet with a `features`
        # Sequential (originally written for EfficientNet-B0).
        target_layer = model.features[-1][0]

    activations, grads = [], []

    def fwd_hook(_module, _inputs, out):
        activations.append(out)

    def bwd_hook(_module, _grad_in, grad_out):
        grads.append(grad_out[0])  # gradient w.r.t. the module's output

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)
    try:
        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = model(x)
            pred = int(logits.argmax(dim=1).item())
            score = logits[0, pred]
            model.zero_grad(set_to_none=True)
            score.backward()
    finally:
        # Always detach the hooks, so a failed call does not leave them
        # registered on target_layer.
        h1.remove()
        h2.remove()

    if not activations or not grads:
        raise RuntimeError(
            "grad_cam: target_layer hooks did not fire; is the layer part of "
            "the model's forward pass?"
        )

    cam = _weighted_cam(activations[-1], grads[-1])
    cam = F.relu(cam)[None, None, ...]  # (1, 1, h, w) for interpolate
    cam = F.interpolate(
        cam, size=(img_size, img_size), mode="bilinear", align_corners=False
    )[0, 0]
    cam = _normalize_cam(cam).detach().cpu().numpy()  # (H, W)

    img_np = x[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-6)

    import matplotlib.cm as cm
    heat = cm.jet(cam)[..., :3]  # drop alpha channel
    overlay = np.clip(0.6 * img_np + 0.4 * heat, 0, 1)

    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
    return {
        "pred": pred,
        "probs": probs,
        "overlay": overlay,
        "input_image": img_np,
        "cam": cam,
    }