HichTala committed · Commit 02988cf · verified · 1 Parent(s): 09be33c

Delete diffusiondet/loss.py

Files changed (1): diffusiondet/loss.py (+0, -415)
diffusiondet/loss.py DELETED
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fvcore.nn import sigmoid_focal_loss_jit
from torch import nn
from torchvision import ops


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight):
    """
    Args:
        gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
        num_fed_loss_classes: minimum number of classes to keep when calculating the federated loss.
            Negative classes are sampled if the number of unique gt_classes is smaller than this value.
        num_classes: number of foreground classes.
        weight: probabilities used to sample negative classes.
    Returns:
        Tensor: classes to keep when calculating the federated loss, including both unique gt
            classes and sampled negative classes.
    """
    unique_gt_classes = torch.unique(gt_classes)
    prob = unique_gt_classes.new_ones(num_classes + 1).float()
    prob[-1] = 0  # never sample the background index
    if len(unique_gt_classes) < num_fed_loss_classes:
        prob[:num_classes] = weight.float().clone()
        prob[unique_gt_classes] = 0
        sampled_negative_classes = torch.multinomial(
            prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
        )
        fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
    else:
        fed_loss_classes = unique_gt_classes
    return fed_loss_classes
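

def _demo_fed_loss_classes():
    # Minimal usage sketch of get_fed_loss_classes with hypothetical values:
    # 80 foreground classes and uniform sampling weights.
    gt_classes = torch.tensor([3, 3, 17])  # two unique gt classes
    weight = torch.ones(80)  # uniform per-class sampling probabilities
    kept = get_fed_loss_classes(gt_classes, num_fed_loss_classes=50,
                                num_classes=80, weight=weight)
    # `kept` holds the unique gt classes {3, 17} plus 48 sampled negatives;
    # the background index (80) is never sampled because prob[-1] = 0.
    return kept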


class CriterionDynamicK(nn.Module):
    """This class computes the loss for DiffusionDet.
    The process happens in two steps:
        1) we compute a Hungarian assignment between ground-truth boxes and the outputs of the model;
        2) we supervise each pair of matched ground-truth / prediction (class and box).
    """

    def __init__(self, config, num_classes, weight_dict):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category.
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
        """
        super().__init__()
        self.config = config
        self.num_classes = num_classes
        self.matcher = HungarianMatcherDynamicK(config)
        self.weight_dict = weight_dict
        self.eos_coef = config.no_object_weight
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma
        if self.use_fed_loss:
            # loss_labels() below reads these two attributes; they are assumed
            # to be provided by the config when the federated loss is enabled.
            self.fed_loss_num_classes = config.fed_loss_num_classes
            self.fed_loss_cls_weights = config.fed_loss_cls_weights

    # copy-paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/roi_heads/fast_rcnn.py#L356
    def loss_labels(self, outputs, targets, indices):
        """Classification loss (sigmoid focal loss or BCE, optionally with federated loss).
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes].
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        batch_size = len(targets)

        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        src_logits_list = []
        target_classes_o_list = []
        for batch_idx in range(batch_size):
            valid_query = indices[batch_idx][0]
            gt_multi_idx = indices[batch_idx][1]
            if len(gt_multi_idx) == 0:
                continue
            bz_src_logits = src_logits[batch_idx]
            target_classes_o = targets[batch_idx]["labels"]
            target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx]

            src_logits_list.append(bz_src_logits[valid_query])
            target_classes_o_list.append(target_classes_o[gt_multi_idx])

        if self.use_focal or self.use_fed_loss:
            num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1

            target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1],
                                                dtype=src_logits.dtype, layout=src_logits.layout,
                                                device=src_logits.device)
            target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

            gt_classes = torch.argmax(target_classes_onehot, dim=-1)
            target_classes_onehot = target_classes_onehot[:, :, :-1]  # drop the no-object column

            src_logits = src_logits.flatten(0, 1)
            target_classes_onehot = target_classes_onehot.flatten(0, 1)
            if self.use_focal:
                cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha,
                                                  gamma=self.focal_loss_gamma, reduction="none")
            else:
                cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")
            if self.use_fed_loss:
                K = self.num_classes
                N = src_logits.shape[0]
                fed_loss_classes = get_fed_loss_classes(
                    gt_classes,
                    num_fed_loss_classes=self.fed_loss_num_classes,
                    num_classes=K,
                    weight=self.fed_loss_cls_weights,
                )
                fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
                fed_loss_classes_mask[fed_loss_classes] = 1
                fed_loss_classes_mask = fed_loss_classes_mask[:K]
                weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()

                loss_ce = torch.sum(cls_loss * weight) / num_boxes
            else:
                loss_ce = torch.sum(cls_loss) / num_boxes

            losses = {'loss_ce': loss_ce}
        else:
            raise NotImplementedError

        return losses
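
    # For intuition on the federated-loss mask above: with K = 80 classes and
    # fed_loss_classes = {3, 17, plus sampled negatives}, `weight` becomes an
    # [N, 80] matrix with ones on the kept class columns and zeros elsewhere,
    # so per-class losses for classes outside the sampled set are ignored
    # (illustrative values; the real set depends on the sampling above).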

    def loss_boxes(self, outputs, targets, indices):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        src_boxes = outputs['pred_boxes']

        batch_size = len(targets)
        pred_box_list = []
        pred_norm_box_list = []
        tgt_box_list = []
        tgt_box_xyxy_list = []
        for batch_idx in range(batch_size):
            valid_query = indices[batch_idx][0]
            gt_multi_idx = indices[batch_idx][1]
            if len(gt_multi_idx) == 0:
                continue
            bz_image_whwh = targets[batch_idx]['image_size_xyxy']
            bz_src_boxes = src_boxes[batch_idx]
            bz_target_boxes = targets[batch_idx]["boxes"]  # normalized (cx, cy, w, h)
            bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"]  # absolute (x1, y1, x2, y2)
            pred_box_list.append(bz_src_boxes[valid_query])
            pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh)  # normalized (x1, y1, x2, y2)
            tgt_box_list.append(bz_target_boxes[gt_multi_idx])
            tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx])

        if len(pred_box_list) != 0:
            src_boxes = torch.cat(pred_box_list)
            src_boxes_norm = torch.cat(pred_norm_box_list)  # normalized (x1, y1, x2, y2)
            target_boxes = torch.cat(tgt_box_list)
            target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
            num_boxes = src_boxes.shape[0]

            losses = {}
            # the L1 loss requires normalized (x1, y1, x2, y2)
            loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none')
            losses['loss_bbox'] = loss_bbox.sum() / num_boxes

            loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
            losses['loss_giou'] = loss_giou.sum() / num_boxes
        else:
            # no matched boxes in the whole batch: return zero losses that keep the graph connected
            losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0,
                      'loss_giou': outputs['pred_boxes'].sum() * 0}

        return losses
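
    # Worked GIoU example for one matched pair (illustrative numbers):
    # boxes (x1, y1, x2, y2) = (0, 0, 2, 2) and (1, 1, 3, 3) have
    # intersection 1 and union 7, so IoU = 1/7; the smallest enclosing box
    # (0, 0, 3, 3) has area 9, so GIoU = 1/7 - (9 - 7)/9 ≈ -0.079 and the
    # per-pair loss is 1 - GIoU ≈ 1.079.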

    def get_loss(self, loss, outputs, targets, indices):
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices, _ = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in ["labels", "boxes"]:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices, _ = self.matcher(aux_outputs, targets)
                for loss in ["labels", "boxes"]:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
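

def _demo_criterion_forward():
    # End-to-end sketch with hypothetical config values and random tensors;
    # the attribute names on `config` mirror what the classes in this file
    # actually read, but the values themselves are illustrative.
    from types import SimpleNamespace
    config = SimpleNamespace(
        no_object_weight=0.1, use_focal=True, use_fed_loss=False,
        alpha=0.25, gamma=2.0,
        class_weight=2.0, giou_weight=2.0, l1_weight=5.0, ota_k=5,
    )
    weight_dict = {"loss_ce": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
    criterion = CriterionDynamicK(config, num_classes=80, weight_dict=weight_dict)

    num_proposals, (h, w) = 300, (480.0, 640.0)
    whwh = torch.tensor([w, h, w, h])
    # random proposals, built as (cx, cy, w, h) then converted to absolute (x1, y1, x2, y2)
    prop_cxcywh = torch.rand(num_proposals, 4) * torch.tensor([w, h, w / 2, h / 2])
    outputs = {
        "pred_logits": torch.randn(1, num_proposals, 80),
        "pred_boxes": ops.box_convert(prop_cxcywh, "cxcywh", "xyxy").unsqueeze(0),
    }
    gt_xyxy = torch.tensor([[100.0, 120.0, 300.0, 360.0]])  # one gt box
    targets = [{
        "labels": torch.tensor([17]),
        "boxes": ops.box_convert(gt_xyxy, "xyxy", "cxcywh") / whwh,  # normalized (cx, cy, w, h)
        "boxes_xyxy": gt_xyxy,  # absolute (x1, y1, x2, y2)
        "image_size_xyxy": whwh,
        "image_size_xyxy_tgt": whwh.expand(1, 4),  # one row per gt box
    }]
    return criterion(outputs, targets)  # {'loss_ce': ..., 'loss_bbox': ..., 'loss_giou': ...}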


def get_in_boxes_info(boxes, target_gts):
    # boxes and target_gts are absolute (cx, cy, w, h)
    xy_target_gts = ops.box_convert(target_gts, 'cxcywh', 'xyxy')  # (x1, y1, x2, y2)

    anchor_center_x = boxes[:, 0].unsqueeze(1)
    anchor_center_y = boxes[:, 1].unsqueeze(1)

    # whether the center of each anchor is inside a gt box
    b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
    b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
    b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
    b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)
    # all four conditions must hold, giving a [num_query, num_gt] mask
    is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
    is_in_boxes_all = is_in_boxes.sum(1) > 0  # [num_query]
    # in fixed center
    center_radius = 2.5
    # Modified to self-adapted sampling --- the center region size depends on the size of the gt boxes
    # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212
    b_l = anchor_center_x > (
            target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
    b_r = anchor_center_x < (
            target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
    b_t = anchor_center_y > (
            target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
    b_b = anchor_center_y < (
            target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)

    is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
    is_in_centers_all = is_in_centers.sum(1) > 0

    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
    is_in_boxes_and_center = (is_in_boxes & is_in_centers)

    return is_in_boxes_anchor, is_in_boxes_and_center
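

def _demo_in_boxes_info():
    # Toy check with hypothetical coordinates, absolute (cx, cy, w, h):
    # one proposal centred on the gt box and one far away from it.
    proposals = torch.tensor([[50.0, 50.0, 20.0, 20.0],
                              [400.0, 400.0, 20.0, 20.0]])
    gts = torch.tensor([[50.0, 50.0, 40.0, 40.0]])  # gt spans (30, 30)-(70, 70)
    fg_mask, in_box_and_center = get_in_boxes_info(proposals, gts)
    # fg_mask == tensor([True, False]): only the first proposal's centre
    # falls inside the gt box (and inside its enlarged centre region).
    return fg_mask, in_box_and_center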


class HungarianMatcherDynamicK(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best
    predictions, while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, config):
        super().__init__()
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.cost_class = config.class_weight
        self.cost_giou = config.giou_weight
        self.cost_bbox = config.l1_weight
        self.ota_k = config.ota_k

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma

        assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs can't be 0"

    def forward(self, outputs, targets):
        """SimOTA-style dynamic matching for DETR-like predictions."""
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]
            if self.use_focal or self.use_fed_loss:
                out_prob = outputs["pred_logits"].sigmoid()  # [batch_size, num_queries, num_classes]
            else:
                out_prob = outputs["pred_logits"].softmax(-1)  # [batch_size, num_queries, num_classes]
            out_bbox = outputs["pred_boxes"]  # [batch_size, num_queries, 4]

            indices = []
            matched_ids = []
            assert bs == len(targets)
            for batch_idx in range(bs):
                bz_boxes = out_bbox[batch_idx]  # [num_proposals, 4]
                bz_out_prob = out_prob[batch_idx]
                bz_tgt_ids = targets[batch_idx]["labels"]
                num_insts = len(bz_tgt_ids)
                if num_insts == 0:  # no objects in this image
                    non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0
                    indices_batchi = (non_valid, torch.arange(0, 0).to(bz_out_prob))
                    matched_qidx = torch.arange(0, 0).to(bz_out_prob)
                    indices.append(indices_batchi)
                    matched_ids.append(matched_qidx)
                    continue

                bz_gtboxs = targets[batch_idx]['boxes']  # [num_gt, 4] normalized (cx, cy, w, h)
                bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy']
                fg_mask, is_in_boxes_and_center = get_in_boxes_info(
                    ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'),  # absolute (cx, cy, w, h)
                    ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh')  # absolute (cx, cy, w, h)
                )

                pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                # Compute the classification cost.
                if self.use_focal:
                    alpha = self.focal_loss_alpha
                    gamma = self.focal_loss_gamma
                    neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
                    pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
                    cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
                elif self.use_fed_loss:
                    # the focal loss degenerates to the naive one
                    neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
                    pos_cost_class = (-(bz_out_prob + 1e-8).log())
                    cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
                else:
                    cost_class = -bz_out_prob[:, bz_tgt_ids]

                # Compute the L1 cost between boxes
                bz_image_size_out = targets[batch_idx]['image_size_xyxy']
                bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt']

                bz_out_bbox_ = bz_boxes / bz_image_size_out  # normalized (x1, y1, x2, y2)
                bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt  # normalized (x1, y1, x2, y2)
                cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)

                cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                # Final cost matrix: [num_query, num_gt]
                cost = (self.cost_bbox * cost_bbox + self.cost_class * cost_class
                        + self.cost_giou * cost_giou + 100.0 * (~is_in_boxes_and_center))
                cost[~fg_mask] = cost[~fg_mask] + 10000.0

                indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])

                indices.append(indices_batchi)
                matched_ids.append(matched_qidx)

        return indices, matched_ids
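
    # For intuition on the focal classification cost above, with
    # alpha = 0.25 and gamma = 2.0 a target-class probability p contributes
    #     pos_cost = 0.25 * (1 - p)^2 * (-log(p))
    # e.g. p = 0.9 gives 0.25 * 0.01 * 0.105 ≈ 2.6e-4, while p = 0.1 gives
    # 0.25 * 0.81 * 2.303 ≈ 0.466, so queries that are already confident
    # about the right class are far cheaper to match (illustrative numbers).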

    def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
        matching_matrix = torch.zeros_like(cost)  # [num_query, num_gt]
        ious_in_boxes_matrix = pair_wise_ious
        n_candidate_k = self.ota_k

        # Use the sum of each gt's top-k IoUs (truncated, at least 1) as its
        # dynamic_k, i.e. the number of queries assigned to that gt.
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)

        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1.0

        del topk_ious, dynamic_ks, pos_idx

        anchor_matching_gt = matching_matrix.sum(1)

        if (anchor_matching_gt > 1).sum() > 0:  # a query matches more than one gt
            _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
            matching_matrix[anchor_matching_gt > 1] *= 0
            matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1

        while (matching_matrix.sum(0) == 0).any():  # some gt still has no matched query
            matched_query_id = matching_matrix.sum(1) > 0
            cost[matched_query_id] += 100000.0
            unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
            for gt_idx in unmatch_id:
                pos_idx = torch.argmin(cost[:, gt_idx])
                matching_matrix[:, gt_idx][pos_idx] = 1.0
            if (matching_matrix.sum(1) > 1).sum() > 0:  # if a query matches more than one gt
                anchor_matching_gt = matching_matrix.sum(1)  # recompute after the new assignments
                _, cost_argmin = torch.min(cost[anchor_matching_gt > 1],
                                           dim=1)  # find the gt with minimal cost for these queries
                matching_matrix[anchor_matching_gt > 1] *= 0  # reset the mapping relationship
                matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1  # keep only the gt with minimal cost

        assert not (matching_matrix.sum(0) == 0).any()
        selected_query = matching_matrix.sum(1) > 0
        gt_indices = matching_matrix[selected_query].max(1)[1]
        assert selected_query.sum() == len(gt_indices)

        cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
        matched_query_id = torch.min(cost, dim=0)[1]

        return (selected_query, gt_indices), matched_query_id
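

def _demo_dynamic_k_matching(matcher):
    # Toy run with 6 queries and 2 gts (hypothetical numbers; in the real
    # pipeline `cost` and `pair_wise_ious` come from forward() above).
    # Assumes matcher.ota_k <= 6 so the top-k over queries is valid.
    cost = torch.tensor([[0.1, 2.0],
                         [1.5, 0.2],
                         [0.4, 1.8],
                         [2.5, 2.5],
                         [3.0, 0.9],
                         [2.2, 3.1]])
    ious = torch.tensor([[0.7, 0.0],
                         [0.0, 0.6],
                         [0.5, 0.1],
                         [0.1, 0.0],
                         [0.0, 0.4],
                         [0.0, 0.0]])
    (selected_query, gt_indices), matched_qidx = matcher.dynamic_k_matching(cost, ious, num_gt=2)
    # With these numbers each gt gets dynamic_k = 1, so query 0 is matched to
    # gt 0 and query 1 to gt 1: selected_query == [True, True, False, ...].
    return (selected_query, gt_indices), matched_qidx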