Create cam_utils.py
Browse files- cam_utils.py +66 -0
cam_utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, torch.nn.functional as F
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np, io, base64
|
| 5 |
+
|
| 6 |
+
def _normalize_cam(cam):
|
| 7 |
+
cam = cam - cam.min()
|
| 8 |
+
cam = cam / (cam.max() + 1e-6)
|
| 9 |
+
return cam
|
| 10 |
+
|
| 11 |
+
def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
    """Compute a Grad-CAM heat map for the model's top-1 prediction on ``img``.

    The image is resized to ``int(img_size * 1.15)``, centre-cropped to
    ``img_size`` and converted to a tensor. No mean/std normalisation is
    applied here, so the model is assumed to accept raw ``ToTensor`` output
    (TODO confirm against how the model was trained).

    Args:
        model: Callable module whose output is indexed as ``logits[batch, class]``.
            It is switched to eval mode, its gradients are cleared, and a
            backward pass is run through it.
        img: PIL image to explain.
        img_size: Side length of the square crop and of the returned maps.
        target_layer: Module whose activations/gradients drive the CAM.
            Defaults to ``model.features[-1][0]``.
        device: Device the input tensor is moved to. The model itself is not
            moved by this function.

    Returns:
        dict with keys:
            ``"pred"``: index of the arg-max logit (int).
            ``"probs"``: softmax of the logits for the single input, as a
                NumPy array.
            ``"overlay"``: 0.6 * input image + 0.4 * jet-coloured CAM,
                clipped to [0, 1].
            ``"input_image"``: the model input as an (H, W, C) NumPy array,
                min-max rescaled.
            ``"cam"``: the normalised (H, W) class-activation map.
    """
    model.eval()
    tfms = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    x = tfms(img).unsqueeze(0).to(device)
    # Make the input part of the autograd graph so a gradient reaches the
    # hooked layer during backward.
    x.requires_grad_(True)

    if target_layer is None:  # EfficientNet-B0 last block
        target_layer = model.features[-1][0]

    activations, grads = [], []

    def fwd_hook(_, __, out):
        activations.append(out)

    def bwd_hook(_, gin, gout):
        grads.append(gout[0])

    # Register inside try/finally so the hooks are always detached, even when
    # the forward or backward pass raises; otherwise they would stay attached
    # to the layer and keep collecting tensors on every later call.
    handles = []
    try:
        handles.append(target_layer.register_forward_hook(fwd_hook))
        handles.append(target_layer.register_full_backward_hook(bwd_hook))

        logits = model(x)
        pred = int(logits.argmax(dim=1).item())
        score = logits[0, pred]
        model.zero_grad(set_to_none=True)
        score.backward()
    finally:
        for handle in handles:
            handle.remove()

    A = activations[-1]  # (B,C,h,w) typical
    if A.dim() == 4:
        A = A[0]  # (C,h,w)
    elif A.dim() == 3:
        pass  # already (C,h,w)
    else:
        A = A.mean(dim=0)

    G = grads[-1]
    if G.dim() == 4:
        G = G[0]  # (C,h,w)

    if G.shape[0] == A.shape[0]:
        # Grad-CAM weighting: spatially averaged gradient per channel.
        weights = G.mean(dim=(1, 2))  # (C,)
        cam = (weights[:, None, None] * A).sum(0)  # (h,w)
    else:
        cam = A.mean(dim=0)  # safe fallback

    cam = F.relu(cam)[None, None, ...]  # (1,1,h,w)
    cam = F.interpolate(
        cam, size=(img_size, img_size), mode='bilinear', align_corners=False
    )[0, 0]
    cam = _normalize_cam(cam).detach().cpu().numpy()  # (H,W)

    img_np = x[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-6)

    # Imported lazily so matplotlib is only needed when a CAM is rendered.
    import matplotlib.cm as cm
    heat = cm.jet(cam)[..., :3]  # drop the alpha channel
    overlay = 0.6 * img_np + 0.4 * heat
    overlay = np.clip(overlay, 0, 1)

    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()

    return {"pred": pred, "probs": probs, "overlay": overlay, "input_image": img_np, "cam": cam}
|