HichTala committed on
Commit 09be33c · verified · 1 Parent(s): f15b97f

Delete diffusiondet/modeling_diffusiondet.py

Files changed (1)
  1. diffusiondet/modeling_diffusiondet.py +0 -424
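
The deleted module defined the repository's custom `DiffusionDet` model (a `PreTrainedModel` subclass tied to `DiffusionDetConfig`). For context, custom modeling code shipped inside a Hub repo like this is normally loaded through the Auto classes with `trust_remote_code=True`; a minimal sketch, with the repo id assumed purely for illustration:

    from transformers import AutoConfig, AutoModel

    # Hypothetical repo id; trust_remote_code=True tells transformers to fetch and
    # execute the repo's own configuration_diffusiondet.py / modeling_diffusiondet.py.
    repo_id = "HichTala/diffusiondet"
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
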
diffusiondet/modeling_diffusiondet.py DELETED
@@ -1,424 +0,0 @@
- import math
- import random
- from collections import namedtuple, OrderedDict
- from dataclasses import dataclass
- from typing import Dict, List, Optional, Tuple, Union
-
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torchvision import ops
- from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
- from transformers import PreTrainedModel
- import wandb
-
- from transformers.utils.backbone_utils import load_backbone
- from .configuration_diffusiondet import DiffusionDetConfig
-
- from .head import HeadDynamicK
- from .loss import CriterionDynamicK
-
- from transformers.utils import ModelOutput
-
- ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
-
-
- def default(val, d):
-     if val is not None:
-         return val
-     return d() if callable(d) else d
-
-
- def extract(a, t, x_shape):
-     """extract the appropriate t index for a batch of indices"""
-     batch_size = t.shape[0]
-     out = a.gather(-1, t)
-     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
-
-
- def cosine_beta_schedule(timesteps, s=0.008):
-     """
-     cosine schedule
-     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
-     """
-     steps = timesteps + 1
-     x = torch.linspace(0, timesteps, steps)
-     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
-     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
-     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
-     return torch.clip(betas, 0, 0.999)
-
- @dataclass
- class DiffusionDetOutput(ModelOutput):
-     """
-     Output type of DiffusionDet.
-     """
-
-     loss: Optional[torch.FloatTensor] = None
-     loss_dict: Optional[Dict] = None
-     logits: torch.FloatTensor = None
-     labels: torch.IntTensor = None
-     pred_boxes: torch.FloatTensor = None
-
- class DiffusionDet(PreTrainedModel):
-     """
-     Implement DiffusionDet
-     """
-     config_class = DiffusionDetConfig
-     main_input_name = "pixel_values"
-
-     def __init__(self, config):
-         super(DiffusionDet, self).__init__(config)
-
-         self.in_features = config.roi_head_in_features
-         self.num_classes = config.num_labels
-         self.num_proposals = config.num_proposals
-         self.num_heads = config.num_heads
-
-         self.backbone = load_backbone(config)
-         self.fpn = FeaturePyramidNetwork(
-             in_channels_list=self.backbone.channels,
-             out_channels=config.fpn_out_channels,
-             # extra_blocks=LastLevelMaxPool(),
-         )
-
-         # build diffusion
-         betas = cosine_beta_schedule(1000)
-         alphas_cumprod = torch.cumprod(1 - betas, dim=0)
-
-         timesteps, = betas.shape
-         sampling_timesteps = config.sample_step
-
-         self.register_buffer('alphas_cumprod', alphas_cumprod)
-         self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-         self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-         self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-         self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
-
-         self.num_timesteps = int(timesteps)
-         self.sampling_timesteps = default(sampling_timesteps, timesteps)
-         self.ddim_sampling_eta = 1.
-         self.scale = config.snr_scale
-         assert self.sampling_timesteps <= timesteps
-
-         roi_input_shape = {
-             'p2': {'stride': 4},
-             'p3': {'stride': 8},
-             'p4': {'stride': 16},
-             'p5': {'stride': 32},
-             'p6': {'stride': 64}
-         }
-         self.head = HeadDynamicK(config, roi_input_shape=roi_input_shape)
-
-         self.deep_supervision = config.deep_supervision
-         self.use_focal = config.use_focal
-         self.use_fed_loss = config.use_fed_loss
-         self.use_nms = config.use_nms
-
-         weight_dict = {
-             "loss_ce": config.class_weight, "loss_bbox": config.l1_weight, "loss_giou": config.giou_weight
-         }
-         if self.deep_supervision:
-             aux_weight_dict = {}
-             for i in range(self.num_heads - 1):
-                 aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
-             weight_dict.update(aux_weight_dict)
-
-         self.criterion = CriterionDynamicK(config, num_classes=self.num_classes, weight_dict=weight_dict)
-
-     def predict_noise_from_start(self, x_t, t, x0):
-         return (
-             (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
-             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-         )
-
-     def model_predictions(self, backbone_feats, images_whwh, x, t):
-         x_boxes = torch.clamp(x, min=-1 * self.scale, max=self.scale)
-         x_boxes = ((x_boxes / self.scale) + 1) / 2
-         x_boxes = ops.box_convert(x_boxes, 'cxcywh', 'xyxy')
-         x_boxes = x_boxes * images_whwh[:, None, :]
-         outputs_class, outputs_coord = self.head(backbone_feats, x_boxes, t)
-
-         x_start = outputs_coord[-1]  # (batch, num_proposals, 4) predict boxes: absolute coordinates (x1, y1, x2, y2)
-         x_start = x_start / images_whwh[:, None, :]
-         x_start = ops.box_convert(x_start, 'xyxy', 'cxcywh')
-         x_start = (x_start * 2 - 1.) * self.scale
-         x_start = torch.clamp(x_start, min=-1 * self.scale, max=self.scale)
-         pred_noise = self.predict_noise_from_start(x, t, x_start)
-
-         return ModelPrediction(pred_noise, x_start), outputs_class, outputs_coord
-
-     @torch.no_grad()
-     def ddim_sample(self, batched_inputs, backbone_feats, images_whwh):
-         bs = len(batched_inputs)
-         image_sizes = batched_inputs.shape
-         shape = (bs, self.num_proposals, 4)
-
-         # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
-         times = torch.linspace(-1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1)
-         times = list(reversed(times.int().tolist()))
-         time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
-
-         img = torch.randn(shape, device=self.device)
-
-         ensemble_score, ensemble_label, ensemble_coord = [], [], []
-         outputs_class, outputs_coord = None, None
-         for time, time_next in time_pairs:
-             time_cond = torch.full((bs,), time, device=self.device, dtype=torch.long)
-
-             preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond)
-             pred_noise, x_start = preds.pred_noise, preds.pred_x_start
-
-             score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
-             threshold = 0.5
-             score_per_image = torch.sigmoid(score_per_image)
-             value, _ = torch.max(score_per_image, -1, keepdim=False)
-             keep_idx = value > threshold
-             num_remain = torch.sum(keep_idx)
-
-             pred_noise = pred_noise[:, keep_idx, :]
-             x_start = x_start[:, keep_idx, :]
-             img = img[:, keep_idx, :]
-
-             if time_next < 0:
-                 img = x_start
-                 continue
-
-             alpha = self.alphas_cumprod[time]
-             alpha_next = self.alphas_cumprod[time_next]
-
-             sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
-             c = (1 - alpha_next - sigma ** 2).sqrt()
-
-             noise = torch.randn_like(img)
-
-             img = x_start * alpha_next.sqrt() + \
-                   c * pred_noise + \
-                   sigma * noise
-
-             img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
-
-             if self.sampling_timesteps > 1:
-                 box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
-                                                                                         outputs_coord[-1])
-                 ensemble_score.append(scores_per_image)
-                 ensemble_label.append(labels_per_image)
-                 ensemble_coord.append(box_pred_per_image)
-
-         if self.sampling_timesteps > 1:
-             box_pred_per_image = torch.cat(ensemble_coord, dim=0)
-             scores_per_image = torch.cat(ensemble_score, dim=0)
-             labels_per_image = torch.cat(ensemble_label, dim=0)
-
-             if self.use_nms:
-                 keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
-                 box_pred_per_image = box_pred_per_image[keep]
-                 scores_per_image = scores_per_image[keep]
-                 labels_per_image = labels_per_image[keep]
-
-             return box_pred_per_image, scores_per_image, labels_per_image
-         else:
-             return self.inference(outputs_class[-1], outputs_coord[-1])
-
-     def q_sample(self, x_start, t, noise=None):
-         if noise is None:
-             noise = torch.randn_like(x_start)
-
-         sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
-         sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
-
-         return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
-
-     def forward(self, pixel_values, labels):
-         """
-         Args:
-             pixel_values: batched images of shape (batch_size, num_channels, height, width).
-             labels: per-image targets providing `size`, `class_labels` and normalized
-                 (cx, cy, w, h) `boxes`, as consumed by `prepare_targets`.
-         """
-         images = pixel_values.to(self.device)
-         images_whwh = list()
-         for image in images:
-             h, w = image.shape[-2:]
-             images_whwh.append(torch.tensor([w, h, w, h], device=self.device))
-         images_whwh = torch.stack(images_whwh)
-
-         features = self.backbone(images)
-         features = OrderedDict(
-             [(key, feature) for key, feature in zip(self.backbone.out_features, features.feature_maps)]
-         )
-         features = self.fpn(features)  # [144, 72, 36, 18]
-         features = [features[f] for f in features.keys()]
-
-         # if self.training:
-         labels = list(map(lambda tensor: tensor.to(self.device), labels))
-         targets, x_boxes, noises, ts = self.prepare_targets(labels)
-
-         ts = ts.squeeze(-1)
-         x_boxes = x_boxes * images_whwh[:, None, :]
-
-         outputs_class, outputs_coord = self.head(features, x_boxes, ts)
-         output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
-
-         if self.deep_supervision:
-             output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
-                                      for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
-
-         loss_dict = self.criterion(output, targets)
-         weight_dict = self.criterion.weight_dict
-         for k in loss_dict.keys():
-             if k in weight_dict:
-                 loss_dict[k] *= weight_dict[k]
-         loss_dict['loss'] = sum([loss_dict[k] for k in weight_dict.keys()])
-
-         wandb_logs_values = ["loss_ce", "loss_bbox", "loss_giou"]
-
-         if self.training:
-             wandb.log({f'train/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
-         else:
-             wandb.log({f'eval/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
-
-         if not self.training:
-             pred_logits, pred_labels, pred_boxes = self.ddim_sample(pixel_values, features, images_whwh)
-             return DiffusionDetOutput(
-                 loss=loss_dict['loss'],
-                 loss_dict=loss_dict,
-                 logits=pred_logits,
-                 labels=pred_labels,
-                 pred_boxes=pred_boxes,
-             )
-
-         return DiffusionDetOutput(
-             loss=loss_dict['loss'],
-             loss_dict=loss_dict,
-             logits=output['pred_logits'],
-             pred_boxes=output['pred_boxes']
-         )
-
-     def prepare_diffusion_concat(self, gt_boxes):
-         """
-         :param gt_boxes: (cx, cy, w, h), normalized
-         :param num_proposals:
-         """
-         t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
-         noise = torch.randn(self.num_proposals, 4, device=self.device)
-
-         num_gt = gt_boxes.shape[0]
-         if not num_gt:  # generate fake gt boxes if empty gt boxes
-             gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
-             num_gt = 1
-
-         if num_gt < self.num_proposals:
-             box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
-                                           device=self.device) / 6. + 0.5  # 3sigma = 1/2 --> sigma: 1/6
-             box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
-             x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
-         elif num_gt > self.num_proposals:
-             select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
-             random.shuffle(select_mask)
-             x_start = gt_boxes[select_mask]
-         else:
-             x_start = gt_boxes
-
-         x_start = (x_start * 2. - 1.) * self.scale
-
-         # noise sample
-         x = self.q_sample(x_start=x_start, t=t, noise=noise)
-
-         x = torch.clamp(x, min=-1 * self.scale, max=self.scale)
-         x = ((x / self.scale) + 1) / 2.
-
-         diff_boxes = ops.box_convert(x, 'cxcywh', 'xyxy')
-
-         return diff_boxes, noise, t
-
-     def prepare_targets(self, targets):
-         new_targets = []
-         diffused_boxes = []
-         noises = []
-         ts = []
-         for target in targets:
-             h, w = target.size
-             image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
-             gt_classes = target.class_labels.to(self.device)
-             gt_boxes = target.boxes.to(self.device)
-             d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)
-             image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
-             gt_boxes = gt_boxes * image_size_xyxy
-             gt_boxes = ops.box_convert(gt_boxes, 'cxcywh', 'xyxy')
-
-             diffused_boxes.append(d_boxes)
-             noises.append(d_noise)
-             ts.append(d_t)
-             new_targets.append({
-                 "labels": gt_classes,
-                 "boxes": target.boxes.to(self.device),
-                 "boxes_xyxy": gt_boxes,
-                 "image_size_xyxy": image_size_xyxy.to(self.device),
-                 "image_size_xyxy_tgt": image_size_xyxy_tgt.to(self.device),
-                 "area": ops.box_area(target.boxes.to(self.device)),
-             })
-
-         return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
-
-     def inference(self, box_cls, box_pred):
-         """
-         Arguments:
-             box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
-                 The tensor predicts the classification probability for each proposal.
-             box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
-                 The tensor predicts 4-vector (x, y, w, h) box
-                 regression values for every proposal.
-
-         Returns:
-             Per-image lists of boxes, scores and labels kept after thresholding/NMS,
-             or a single image's tensors when `sampling_timesteps > 1` (used for
-             ensembling in `ddim_sample`).
-         """
-         results = []
-         boxes_output = []
-         logits_output = []
-         labels_output = []
-
-         if self.use_focal or self.use_fed_loss:
-             scores = torch.sigmoid(box_cls)
-             labels = torch.arange(self.num_classes, device=self.device). \
-                 unsqueeze(0).repeat(self.num_proposals, 1).flatten(0, 1)
-
-             for i, (scores_per_image, box_pred_per_image) in enumerate(zip(
-                     scores, box_pred
-             )):
-                 scores_per_image, topk_indices = scores_per_image.flatten(0, 1).topk(self.num_proposals, sorted=False)
-                 labels_per_image = labels[topk_indices]
-                 box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(1, self.num_classes, 1).view(-1, 4)
-                 box_pred_per_image = box_pred_per_image[topk_indices]
-
-                 if self.sampling_timesteps > 1:
-                     return box_pred_per_image, scores_per_image, labels_per_image
-
-                 if self.use_nms:
-                     keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
-                     box_pred_per_image = box_pred_per_image[keep]
-                     scores_per_image = scores_per_image[keep]
-                     labels_per_image = labels_per_image[keep]
-
-                 boxes_output.append(box_pred_per_image)
-                 logits_output.append(scores_per_image)
-                 labels_output.append(labels_per_image)
-         else:
-             # For each box we assign the best class or the second best if the best one is `no_object`.
-             scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
-
-             for i, (scores_per_image, labels_per_image, box_pred_per_image) in enumerate(zip(
-                     scores, labels, box_pred
-             )):
-                 if self.sampling_timesteps > 1:
-                     return box_pred_per_image, scores_per_image, labels_per_image
-
-                 if self.use_nms:
-                     keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
-                     box_pred_per_image = box_pred_per_image[keep]
-                     scores_per_image = scores_per_image[keep]
-                     labels_per_image = labels_per_image[keep]
-
-                 boxes_output.append(box_pred_per_image)
-                 logits_output.append(scores_per_image)
-                 labels_output.append(labels_per_image)
-
-         return boxes_output, logits_output, labels_output
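
The diffusion bookkeeping in the deleted file (the cosine beta schedule, the registered `sqrt_alphas_cumprod` buffers, and `q_sample`) reduces to the standard closed-form forward process x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise. A small self-contained sketch of just that part, mirroring the deleted code's constants but intended only as an illustration, not as the repository's API:

    import math
    import torch

    def cosine_beta_schedule(timesteps, s=0.008):
        # Cosine noise schedule, as in the deleted file.
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)

    betas = cosine_beta_schedule(1000)
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)   # abar_t, shape (1000,)

    # q_sample: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
    x_start = torch.rand(8, 4) * 2 - 1                 # toy signal-scaled "boxes" in [-1, 1]
    t = torch.randint(0, 1000, (8,))
    noise = torch.randn_like(x_start)
    x_t = (alphas_cumprod[t].sqrt().unsqueeze(-1) * x_start
           + (1 - alphas_cumprod[t]).sqrt().unsqueeze(-1) * noise)
    print(x_t.shape)  # torch.Size([8, 4])
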