stefanosgikas committed on
Commit fb56b04 · verified · 1 Parent(s): 333dbc8

Upload 2 files
architecture/efficientvit.py ADDED
import torch
import itertools

from timm.models.vision_transformer import trunc_normal_
from timm.models.layers import SqueezeExcite
from timm.models.registry import register_model

class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        # Fold the BatchNorm statistics into the convolution weights and bias.
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(0), w.shape[2:],
                            stride=self.c.stride, padding=self.c.padding,
                            dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

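# Illustrative sketch (assumption, not part of the original upload): in eval mode the fused
# Conv2d produced by Conv2d_BN.fuse() should match the Conv+BN pair it replaces.
def _demo_conv2d_bn_fuse():
    m = Conv2d_BN(8, 16, ks=3, stride=1, pad=1).eval()
    x = torch.randn(2, 8, 14, 14)
    fused = m.fuse()
    # The fused convolution matches the original module up to numerical precision.
    print(torch.allclose(m(x), fused(x), atol=1e-5))
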
class BN_Linear(torch.nn.Sequential):
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        bn, l = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

class PatchMerging(torch.nn.Module):
    def __init__(self, dim, out_dim, input_resolution):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0, resolution=input_resolution)
        self.act = torch.nn.ReLU()
        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, resolution=input_resolution)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0, resolution=input_resolution // 2)

    def forward(self, x):
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        return x

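# Illustrative sketch (assumption, not part of the original upload): PatchMerging halves the
# spatial resolution while mapping dim -> out_dim channels.
def _demo_patch_merging():
    pm = PatchMerging(dim=64, out_dim=128, input_resolution=14).eval()
    x = torch.randn(1, 64, 14, 14)
    print(pm(x).shape)  # expected: torch.Size([1, 128, 7, 7])
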
class Residual(torch.nn.Module):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            # Stochastic depth: randomly drop the residual branch per sample during training.
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

class FFN(torch.nn.Module):
    def __init__(self, ed, h, resolution):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h, resolution=resolution)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0, resolution=resolution)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x

class CascadedGroupAttention(torch.nn.Module):
    r""" Cascaded Group Attention.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution, corresponding to the window size.
        kernels (List[int]): The kernel size of the dw conv on query.
    """
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
            dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, resolution=resolution))
        self.qkvs = torch.nn.ModuleList(qkvs)
        self.dws = torch.nn.ModuleList(dws)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,C,H,W)
        B, C, H, W = x.shape
        trainingab = self.attention_biases[:, self.attention_bias_idxs]
        feats_in = x.chunk(len(self.qkvs), dim=1)
        feats_out = []
        feat = feats_in[0]
        for i, qkv in enumerate(self.qkvs):
            if i > 0:  # add the previous output to the input
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)  # B, C/h, H, W
            q = self.dws[i](q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)  # B, C/h, N
            attn = (
                (q.transpose(-2, -1) @ k) * self.scale
                +
                (trainingab[i] if self.training else self.ab[i])
            )
            attn = attn.softmax(dim=-1)  # BNN
            feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)  # BCHW
            feats_out.append(feat)
        x = self.proj(torch.cat(feats_out, 1))
        return x

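# Illustrative sketch (assumption, not part of the original upload): cascaded group attention
# splits the channels evenly across heads and preserves the spatial shape of its input.
def _demo_cascaded_group_attention():
    attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, resolution=7,
                                  kernels=[5, 5, 5, 5]).eval()
    x = torch.randn(1, 64, 7, 7)
    print(attn(x).shape)  # expected: torch.Size([1, 64, 7, 7])
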
class LocalWindowAttention(torch.nn.Module):
    r""" Local Window Attention.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution.
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.
    """
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        assert window_resolution > 0, 'window_size must be greater than 0'
        self.window_resolution = window_resolution

        window_resolution = min(window_resolution, resolution)
        self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
                                           attn_ratio=attn_ratio,
                                           resolution=window_resolution,
                                           kernels=kernels,)

    def forward(self, x):
        H = W = self.resolution
        B, C, H_, W_ = x.shape
        # Only check this for classification models
        assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))

        if H <= self.window_resolution and W <= self.window_resolution:
            x = self.attn(x)
        else:
            x = x.permute(0, 2, 3, 1)
            pad_b = (self.window_resolution - H %
                     self.window_resolution) % self.window_resolution
            pad_r = (self.window_resolution - W %
                     self.window_resolution) % self.window_resolution
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_resolution
            nW = pW // self.window_resolution
            # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
            x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape(
                B * nH * nW, self.window_resolution, self.window_resolution, C
            ).permute(0, 3, 1, 2)
            x = self.attn(x)
            # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
            x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution,
                                           C).transpose(2, 3).reshape(B, pH, pW, C)
            if padding:
                x = x[:, :H, :W].contiguous()
            x = x.permute(0, 3, 1, 2)
        return x

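# Illustrative sketch (assumption, not part of the original upload): on a 14x14 feature map
# with a 7x7 window, the attention is computed inside four windows and the input shape is restored.
def _demo_local_window_attention():
    attn = LocalWindowAttention(dim=64, key_dim=16, num_heads=4,
                                resolution=14, window_resolution=7).eval()
    x = torch.randn(1, 64, 14, 14)
    print(attn(x).shape)  # expected: torch.Size([1, 64, 14, 14])
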
class EfficientViTBlock(torch.nn.Module):
    """ A basic EfficientViT building block.

    Args:
        type (str): Type for token mixer. Default: 's' for self-attention.
        ed (int): Number of input channels.
        kd (int): Dimension for query and key in the token mixer.
        nh (int): Number of attention heads.
        ar (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution.
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.
    """
    def __init__(self, type,
                 ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()

        self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
        self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))

        if type == 's':
            self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar,
                                                       resolution=resolution, window_resolution=window_resolution, kernels=kernels))

        self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
        self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))

    def forward(self, x):
        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))

class EfficientViT(torch.nn.Module):
    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 stages=['s', 's', 's'],
                 embed_dim=[64, 128, 192],
                 key_dim=[16, 16, 16],
                 depth=[1, 2, 3],
                 num_heads=[4, 4, 4],
                 window_size=[7, 7, 7],
                 kernels=[5, 5, 5, 5],
                 down_ops=[['subsample', 2], ['subsample', 2], ['']],
                 distillation=False,):
        super().__init__()

        resolution = img_size
        # Patch embedding
        self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution), torch.nn.ReLU(),
                                               Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2), torch.nn.ReLU(),
                                               Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4), torch.nn.ReLU(),
                                               Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1, resolution=resolution // 8))

        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
        self.blocks1 = []
        self.blocks2 = []
        self.blocks3 = []

        # Build EfficientViT blocks
        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
            for d in range(dpth):
                eval('self.blocks' + str(i + 1)).append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
            if do[0] == 'subsample':
                # Build EfficientViT downsample block
                # ('Subsample' stride)
                blk = eval('self.blocks' + str(i + 2))
                resolution_ = (resolution - 1) // do[1] + 1
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i], resolution=resolution)),
                                               Residual(FFN(embed_dim[i], int(embed_dim[i] * 2), resolution)),))
                blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))
                resolution = resolution_
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1], resolution=resolution)),
                                               Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2), resolution)),))
        self.blocks1 = torch.nn.Sequential(*self.blocks1)
        self.blocks2 = torch.nn.Sequential(*self.blocks2)
        self.blocks3 = torch.nn.Sequential(*self.blocks3)

        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x

# Model configuration dictionaries. The registered factory functions below reuse these names;
# the dicts are captured as default arguments before the names are rebound.
EfficientViT_d = {
    'img_size': 224,
    'patch_size': 16,
    'embed_dim': [96, 144, 400],   # 192, 288, 384
    'depth': [1, 3, 4],            # 1, 3, 4 ----------------- [1, 1, 2]
    'num_heads': [3, 3, 4],        # 3, 3, 4
    'window_size': [7, 7, 7],
    'kernels': [7, 5, 3, 3],
}

EfficientViT_w = {
    'img_size': 224,
    'patch_size': 16,
    'embed_dim': [192, 288, 96],   # 400 192
    'depth': [1, 1, 1],            # 1, 3, 4 ----------------- [1, 1, 2]
    'num_heads': [3, 3, 4],        # 3, 3, 4
    'window_size': [7, 7, 7],
    'kernels': [7, 5, 3, 3],
}

@register_model
def EfficientViT_d(num_classes=5, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_d):
    model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg)
    if fuse:
        replace_batchnorm(model)
    return model


@register_model
def EfficientViT_w(num_classes=5, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_w):
    model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg)
    if fuse:
        replace_batchnorm(model)
    return model

def replace_batchnorm(net):
    # Recursively fuse Conv+BN pairs and strip any remaining BatchNorm layers for inference.
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)
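
# Illustrative usage sketch (assumption, not part of the original upload): build the
# registered 5-class model, run a dummy batch, then fold BatchNorm for inference.
if __name__ == "__main__":
    net = EfficientViT_d(num_classes=5).eval()   # eval() so BatchNorm1d accepts a batch of one
    dummy = torch.randn(1, 3, 224, 224)
    print(net(dummy).shape)   # expected: torch.Size([1, 5])
    replace_batchnorm(net)    # fuse Conv+BN pairs in place
    print(net(dummy).shape)   # still torch.Size([1, 5])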
architecture/spectformer.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math
import numpy as np
from pytorch_wavelets import DWTForward, DWTInverse  # (or import DWT, IDWT)

class SpectralGatingNetwork(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # These learnable complex weights must match the spatial size of the stage that uses
        # them: h = H and w = W // 2 + 1 because of the one-sided rfft2.
        if dim == 64:   # 96 for large model, 64 for small and base model
            self.h = 56  # H
            self.w = 29  # (W/2)+1
            self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02)
        if dim == 128:
            self.h = 28  # H
            self.w = 15  # (W/2)+1, this is due to rfft2
            self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02)
        if dim == 96:   # 96 for large model, 64 for small and base model
            self.h = 56  # H
            self.w = 29  # (W/2)+1
            self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02)
        if dim == 192:
            self.h = 28  # H
            self.w = 15  # (W/2)+1, this is due to rfft2
            self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        # Cast to float32: half-precision inputs make rfft2 fail with a dtype mismatch
        # against the float32 complex weights.
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        # Apply the learned complex-valued filter in the frequency domain.
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)  # reshape back to token form; permute is not the same as reshape/view
        return x

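# Illustrative sketch (assumption, not part of the original upload): the spectral gate is a
# token mixer that filters the 56x56 stage-1 feature map (dim 64) in the frequency domain
# and returns tokens of the same shape.
def _demo_spectral_gating():
    sgn = SpectralGatingNetwork(dim=64)
    x = torch.randn(2, 56 * 56, 64)   # B, N, C with N = H * W
    print(sgn(x, 56, 56).shape)       # expected: torch.Size([2, 3136, 64])
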
def rand_bbox(size, lam, scale=1):
    # Sample a random box for CutMix-style token mixing; `lam` controls the cut area.
    W = size[1] // scale
    H = size[2] // scale
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

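# Illustrative sketch (assumption, not part of the original upload): rand_bbox samples a
# random box inside the downscaled token grid; lam close to 1 gives a small cut region.
def _demo_rand_bbox():
    size = (2, 224, 224, 3)                   # B, H, W, C layout as indexed by rand_bbox
    print(rand_bbox(size, lam=0.8, scale=8))  # e.g. (bbx1, bby1, bbx2, bby2) within a 28x28 grid
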
class ClassAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.kv = nn.Linear(dim, dim * 2)
        self.q = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        B, N, C = x.shape
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
        cls_embed = self.proj(cls_embed)
        return cls_embed

class FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class ClassBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = ClassAttention(dim, num_heads)
        self.mlp = FFN(dim, int(dim * mlp_ratio))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.attn(self.norm1(x))
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed))
        return torch.cat([cls_embed, x[:, 1:]], dim=1)

class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        # Return the attention weights alongside the output so they can be inspected.
        return x, attn

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1,
                 block_type='wave'
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        if block_type == 'std_att':
            self.attn = Attention(dim, num_heads)
        else:
            self.attn = SpectralGatingNetwork(dim)
        self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    # Original forward (kept for reference), before attention maps were returned:
    # def forward(self, x, H, W):
    #     x = x + self.drop_path(self.attn(self.norm1(x), H, W))
    #     x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
    #     return x

    def forward(self, x, H, W):
        # Standard attention also returns its attention weights; the spectral gate returns a plain tensor.
        attn_output, attn_weights = self.attn(self.norm1(x), H, W) if isinstance(self.attn, Attention) else (self.attn(self.norm1(x), H, W), None)
        x = x + self.drop_path(attn_output)
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        # Optionally return attention weights for visualization or analysis
        return (x, attn_weights) if attn_weights is not None else x

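# Illustrative sketch (assumption, not part of the original upload): a 'wave' block returns a
# plain tensor, while a 'std_att' block also returns its attention map.
def _demo_block_types():
    # The FFT gate for dim 64 is sized for a 56x56 grid, so feed 56*56 tokens here.
    wave_blk = Block(dim=64, num_heads=2, mlp_ratio=8, block_type='wave')
    x = torch.randn(2, 56 * 56, 64)
    print(wave_blk(x, 56, 56).shape)   # expected: torch.Size([2, 3136, 64])
    # Plain self-attention works at any resolution and also returns the attention weights.
    att_blk = Block(dim=64, num_heads=2, mlp_ratio=8, block_type='std_att')
    y = torch.randn(2, 14 * 14, 64)
    out, attn = att_blk(y, 14, 14)
    print(out.shape, attn.shape)       # torch.Size([2, 196, 64]) torch.Size([2, 2, 196, 196])
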
class DownSamples(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

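# Illustrative sketch (assumption, not part of the original upload): DownSamples halves H and W
# with a stride-2 convolution and returns flattened tokens plus the new spatial size.
def _demo_downsamples():
    ds = DownSamples(in_channels=64, out_channels=128)
    x = torch.randn(1, 64, 56, 56)
    tokens, H, W = ds(x)
    print(tokens.shape, H, W)  # expected: torch.Size([1, 784, 128]) 28 28
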
class Stem(nn.Module):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Conv2d(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

class SpectFormer(nn.Module):
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 stem_hidden_dim=32,
                 embed_dims=[64, 128, 320, 448],
                 num_heads=[2, 4, 10, 14],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[4, 2, 1, 1],
                 num_stages=4,
                 token_label=False,
                 **kwargs
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            block = nn.ModuleList([Block(
                dim=embed_dims[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratios[i],
                drop_path=dpr[cur + j],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[i],
                block_type='wave' if i < 2 else 'std_att')
                for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            ClassBlock(
                dim=embed_dims[-1],
                num_heads=num_heads[-1],
                mlp_ratio=mlp_ratios[-1],
                norm_layer=norm_layer)
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################
        self.return_dense = token_label
        self.mix_token = token_label
        self.beta = 1.0
        self.pooling_scale = 8
        if self.return_dense:
            self.aux_head = nn.Linear(
                embed_dims[-1],
                num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_cls(self, x):
        B, N, C = x.shape
        cls_tokens = x.mean(dim=1, keepdim=True)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            x = block(x)
        return x

    # Original forward_features (kept for reference), before attention maps were collected:
    # def forward_features(self, x):
    #     B = x.shape[0]
    #     for i in range(self.num_stages):
    #         patch_embed = getattr(self, f"patch_embed{i + 1}")
    #         block = getattr(self, f"block{i + 1}")
    #         x, H, W = patch_embed(x)
    #         for blk in block:
    #             x = blk(x, H, W)
    #         tokens = x
    #
    #         if i != self.num_stages - 1:
    #             norm = getattr(self, f"norm{i + 1}")
    #             x = norm(x)
    #             x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
    #
    #     x = self.forward_cls(x)[:, 0]
    #     norm = getattr(self, f"norm{self.num_stages}")
    #     x = norm(x)
    #     return x, tokens

    def forward_features(self, x):
        B = x.shape[0]
        attention_maps = []  # Collect attention maps if available
        tokens = None  # Initialize tokens to ensure scope coverage

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, H, W = patch_embed(x)

            for blk in block:
                outputs = blk(x, H, W)
                if isinstance(outputs, tuple):
                    x, attn_weights = outputs
                    attention_maps.append(attn_weights)  # Store attention maps
                else:
                    x = outputs

            tokens = x  # Update tokens with the latest block output

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.forward_cls(x)[:, 0]  # Further processing for the classification token
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x, tokens, attention_maps

    # Original forward (kept for reference), before attention maps were collected:
    # def forward(self, x):
    #     if not self.return_dense:
    #         x, tokens = self.forward_features(x)
    #         x = self.head(x)
    #         return x, tokens
    #     else:
    #         x, H, W = self.forward_embeddings(x)
    #         # mix token, see token labeling for details.
    #         if self.mix_token and self.training:
    #             lam = np.random.beta(self.beta, self.beta)
    #             patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
    #             bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
    #             temp_x = x.clone()
    #             sbbx1, sbby1, sbbx2, sbby2 = self.pooling_scale * bbx1, self.pooling_scale * bby1, \
    #                 self.pooling_scale * bbx2, self.pooling_scale * bby2
    #             temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
    #             x = temp_x
    #         else:
    #             bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
    #
    #         x = self.forward_tokens(x, H, W)
    #         x_cls = self.head(x[:, 0])
    #         x_aux = self.aux_head(
    #             x[:, 1:]
    #         )  # generate classes in all feature tokens, see token labeling
    #
    #         if not self.training:
    #             return x_cls + 0.5 * x_aux.max(1)[0]
    #
    #         if self.mix_token and self.training:  # reverse "mix token", see token labeling for details.
    #             x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
    #             temp_x = x_aux.clone()
    #             temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
    #             x_aux = temp_x
    #             x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
    #
    #         return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    def forward(self, x):
        attention_maps = []  # Collect attention maps from all blocks

        if not self.return_dense:
            # Retrieve main output, tokens, and attention maps
            x, tokens, new_attention_maps = self.forward_features(x)
            attention_maps.extend(new_attention_maps)
            x = self.head(x)
            return x, tokens, attention_maps
        else:
            # Dense token labeling: embed first, optionally mix tokens on the spatial grid
            # (as in the reference implementation above), then run the transformer stages.
            x, H, W = self.forward_embeddings(x)

            if self.mix_token and self.training:
                lam = np.random.beta(self.beta, self.beta)
                patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
                bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
                sbbx1, sbby1, sbbx2, sbby2 = self.pooling_scale * bbx1, self.pooling_scale * bby1, \
                    self.pooling_scale * bbx2, self.pooling_scale * bby2
                temp_x = x.clone()
                temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
                x = temp_x
            else:
                bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0  # Default to zero if no mixing

            x, new_attention_maps = self.forward_tokens(x, H, W)
            attention_maps.extend(new_attention_maps)

            x_cls = self.head(x[:, 0])
            x_aux = self.aux_head(x[:, 1:])  # Class prediction for all feature tokens

            if not self.training:
                return x_cls + 0.5 * x_aux.max(1)[0], attention_maps

            return x_cls, x_aux, (bbx1, bby1, bbx2, bby2), attention_maps

    def forward_tokens(self, x, H, W):
        B = x.shape[0]
        attention_maps = []
        x = x.view(B, -1, x.size(-1))

        for i in range(self.num_stages):
            if i != 0:
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                x, H, W = patch_embed(x)

            block = getattr(self, f"block{i + 1}")
            for blk in block:
                outputs = blk(x, H, W)
                if isinstance(outputs, tuple):
                    x, attn_weights = outputs
                    attention_maps.append(attn_weights)
                else:
                    x = outputs

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.forward_cls(x)
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x, attention_maps

    def forward_embeddings(self, x):
        patch_embed = getattr(self, f"patch_embed{0 + 1}")
        x, H, W = patch_embed(x)
        x = x.view(x.size(0), H, W, -1)
        return x, H, W

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

@register_model
def spectformer_t_d(pretrained=False, **kwargs):
    model = SpectFormer(
        stem_hidden_dim=32,
        embed_dims=[64, 128, 160, 400],   # 64, 128, 320, 448 ----- [64, 128, 160, 200]
        num_heads=[2, 4, 10, 16],         # 2, 4, 10, 16 ---------- [2, 4, 10, 10]
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[1, 2, 5, 2],              # 1, 2, 3, 1 ------------ [1, 1, 1, 1]
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def spectformer_t_w(pretrained=False, **kwargs):
    model = SpectFormer(
        stem_hidden_dim=32,
        embed_dims=[64, 128, 320, 96],    # 64, 128, 320, 448 ----- [64, 128, 160, 200]
        num_heads=[2, 4, 10, 16],         # 2, 4, 10, 16 ---------- [2, 4, 10, 10]
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[1, 1, 1, 1],              # 1, 2, 3, 1 ------------ [1, 1, 1, 1]
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    model.default_cfg = _cfg()
    return model
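
# Illustrative usage sketch (assumption, not part of the original upload): build the tiny
# registered variant and run a dummy batch; the forward pass returns logits, the final
# token map, and any collected attention maps.
if __name__ == "__main__":
    model = spectformer_t_d(num_classes=5).eval()
    dummy = torch.randn(1, 3, 224, 224)
    logits, tokens, attention_maps = model(dummy)
    print(logits.shape)         # expected: torch.Size([1, 5])
    print(tokens.shape)         # final-stage token map
    print(len(attention_maps))  # one entry per std_att block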