mgbam commited on
Commit
e0076d2
·
verified ·
1 Parent(s): 2d1deb3

Create cam_utils.py

Browse files
Files changed (1) hide show
  1. cam_utils.py +66 -0
cam_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torch.nn.functional as F
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import numpy as np, io, base64
5
+
6
+ def _normalize_cam(cam):
7
+ cam = cam - cam.min()
8
+ cam = cam / (cam.max() + 1e-6)
9
+ return cam
10
+
11
+ def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
12
+ model.eval()
13
+ tfms = transforms.Compose([
14
+ transforms.Resize(int(img_size*1.15)),
15
+ transforms.CenterCrop(img_size),
16
+ transforms.ToTensor()
17
+ ])
18
+ x = tfms(img).unsqueeze(0).to(device)
19
+ x.requires_grad_(True)
20
+
21
+ if target_layer is None: # EfficientNet-B0 last block
22
+ target_layer = model.features[-1][0]
23
+
24
+ activations, grads = [], []
25
+ def fwd_hook(_, __, out): activations.append(out)
26
+ def bwd_hook(_, gin, gout): grads.append(gout[0])
27
+
28
+ h1 = target_layer.register_forward_hook(fwd_hook)
29
+ h2 = target_layer.register_full_backward_hook(bwd_hook)
30
+
31
+ logits = model(x)
32
+ pred = int(logits.argmax(dim=1).item())
33
+ score = logits[0, pred]
34
+ model.zero_grad(set_to_none=True)
35
+ score.backward()
36
+
37
+ A = activations[-1] # (B,C,h,w) typical
38
+ if A.dim() == 4: A = A[0] # (C,h,w)
39
+ elif A.dim() == 3: pass # already (C,h,w)
40
+ else: A = A.mean(dim=0)
41
+
42
+ G = grads[-1]
43
+ if G.dim() == 4: G = G[0] # (C,h,w)
44
+
45
+ if G.shape[0] == A.shape[0]:
46
+ weights = G.mean(dim=(1,2)) # (C,)
47
+ cam = (weights[:, None, None] * A).sum(0) # (h,w)
48
+ else:
49
+ cam = A.mean(dim=0) # safe fallback
50
+
51
+ cam = F.relu(cam)[None, None, ...] # (1,1,h,w)
52
+ cam = F.interpolate(cam, size=(img_size, img_size), mode='bilinear', align_corners=False)[0,0]
53
+ cam = _normalize_cam(cam).detach().cpu().numpy() # (H,W)
54
+
55
+ img_np = (x[0].detach().cpu().permute(1,2,0).numpy())
56
+ img_np = (img_np - img_np.min())/(img_np.max()-img_np.min()+1e-6)
57
+
58
+ import matplotlib.cm as cm
59
+ heat = cm.jet(cam)[..., :3]
60
+ overlay = 0.6*img_np + 0.4*heat
61
+ overlay = np.clip(overlay, 0, 1)
62
+
63
+ probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
64
+
65
+ h1.remove(); h2.remove()
66
+ return {"pred": pred, "probs": probs, "overlay": overlay, "input_image": img_np, "cam": cam}