# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
if __name__ == "__main__":
    # for debug only
    import os, sys

    sys.path.append(os.path.dirname(sys.path[0]))

from torchvision.datasets.vision import VisionDataset

import json
from pathlib import Path
import random
import os
import warnings
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

from datasets.data_util import preparing_dataset
import datasets.transforms as T
from util.box_ops import box_cxcywh_to_xyxy, box_iou

__all__ = ["build"]
class label2compat:
    """Map the raw COCO category ids (1-90, with gaps) to contiguous labels.

    __call__ writes target["label_compat"] with 0-based labels in [0, 79].
    """

    def __init__(self) -> None:
        self.category_map_str = {
            "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8,
            "9": 9, "10": 10, "11": 11, "13": 12, "14": 13, "15": 14, "16": 15,
            "17": 16, "18": 17, "19": 18, "20": 19, "21": 20, "22": 21, "23": 22,
            "24": 23, "25": 24, "27": 25, "28": 26, "31": 27, "32": 28, "33": 29,
            "34": 30, "35": 31, "36": 32, "37": 33, "38": 34, "39": 35, "40": 36,
            "41": 37, "42": 38, "43": 39, "44": 40, "46": 41, "47": 42, "48": 43,
            "49": 44, "50": 45, "51": 46, "52": 47, "53": 48, "54": 49, "55": 50,
            "56": 51, "57": 52, "58": 53, "59": 54, "60": 55, "61": 56, "62": 57,
            "63": 58, "64": 59, "65": 60, "67": 61, "70": 62, "72": 63, "73": 64,
            "74": 65, "75": 66, "76": 67, "77": 68, "78": 69, "79": 70, "80": 71,
            "81": 72, "82": 73, "84": 74, "85": 75, "86": 76, "87": 77, "88": 78,
            "89": 79, "90": 80,
        }
        self.category_map = {int(k): v for k, v in self.category_map_str.items()}

    def __call__(self, target, img=None):
        labels = target["labels"]
        res = torch.zeros(labels.shape, dtype=labels.dtype)
        for idx, item in enumerate(labels):
            res[idx] = self.category_map[item.item()] - 1
        target["label_compat"] = res
        if img is not None:
            return target, img
        else:
            return target
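# A hypothetical illustration of label2compat (values chosen for the example):
# raw COCO ids [1, 13, 90] map through category_map to [1, 12, 80], and the -1
# shift stores target["label_compat"] = tensor([0, 11, 79]).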
class label_compat2onehot:
    def __init__(self, num_class=80, num_output_objs=1):
        self.num_class = num_class
        self.num_output_objs = num_output_objs
        if num_output_objs != 1:
            raise DeprecationWarning(
                "num_output_objs!=1, which is only used for comparison"
            )

    def __call__(self, target, img=None):
        labels = target["label_compat"]
        place_dict = {k: 0 for k in range(self.num_class)}
        if self.num_output_objs == 1:
            res = torch.zeros(self.num_class)
            for i in labels:
                itm = i.item()
                res[itm] = 1.0
        else:
            # compat with baseline
            res = torch.zeros(self.num_class, self.num_output_objs)
            for i in labels:
                itm = i.item()
                res[itm][place_dict[itm]] = 1.0
                place_dict[itm] += 1
        target["label_compat_onehot"] = res
        if img is not None:
            return target, img
        else:
            return target
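# A hypothetical sketch: with num_class=80 and label_compat = tensor([0, 0, 5]),
# target["label_compat_onehot"] is an 80-dim multi-hot vector with 1.0 at
# indices 0 and 5 (duplicate labels collapse to a single 1.0).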
class box_label_catter:
    def __init__(self):
        pass

    def __call__(self, target, img=None):
        labels = target["label_compat"]
        boxes = target["boxes"]
        box_label = torch.cat((boxes, labels.unsqueeze(-1)), 1)
        target["box_label"] = box_label
        if img is not None:
            return target, img
        else:
            return target
class RandomSelectBoxlabels:
    def __init__(
        self,
        num_classes,
        leave_one_out=False,
        blank_prob=0.8,
        prob_first_item=0.0,
        prob_random_item=0.0,
        prob_last_item=0.8,
        prob_stop_sign=0.2,
    ) -> None:
        self.num_classes = num_classes
        self.leave_one_out = leave_one_out
        self.blank_prob = blank_prob
        self.set_state(
            prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
        )

    def get_state(self):
        return [
            self.prob_first_item,
            self.prob_random_item,
            self.prob_last_item,
            self.prob_stop_sign,
        ]
    def set_state(
        self, prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
    ):
        sum_prob = prob_first_item + prob_random_item + prob_last_item + prob_stop_sign
        assert abs(sum_prob - 1) < 1e-6, (
            f"All probs should sum to 1, but sum = {sum_prob}. "
            f"prob_first_item: {prob_first_item}, prob_random_item: {prob_random_item}, "
            f"prob_last_item: {prob_last_item}, prob_stop_sign: {prob_stop_sign}"
        )

        self.prob_first_item = prob_first_item
        self.prob_random_item = prob_random_item
        self.prob_last_item = prob_last_item
        self.prob_stop_sign = prob_stop_sign
    def sample_for_pred_first_item(self, box_label: torch.FloatTensor):
        # no boxes are known; everything must be predicted
        box_label_known = torch.Tensor(0, 5)
        box_label_unknown = box_label
        return box_label_known, box_label_unknown

    def sample_for_pred_random_item(self, box_label: torch.FloatTensor):
        # a random subset of the boxes is known, the rest unknown
        n_select = int(random.random() * box_label.shape[0])
        box_label = box_label[torch.randperm(box_label.shape[0])]
        box_label_known = box_label[:n_select]
        box_label_unknown = box_label[n_select:]
        return box_label_known, box_label_unknown

    def sample_for_pred_last_item(self, box_label: torch.FloatTensor):
        # the first box of each class is unknown; repeats of a class are known
        box_label_perm = box_label[torch.randperm(box_label.shape[0])]
        known_label_list = []
        box_label_known = []
        box_label_unknown = []
        for item in box_label_perm:
            label_i = item[4].item()
            if label_i in known_label_list:
                box_label_known.append(item)
            else:
                # first item of this class
                box_label_unknown.append(item)
                known_label_list.append(label_i)
        box_label_known = (
            torch.stack(box_label_known)
            if len(box_label_known) > 0
            else torch.Tensor(0, 5)
        )
        box_label_unknown = (
            torch.stack(box_label_unknown)
            if len(box_label_unknown) > 0
            else torch.Tensor(0, 5)
        )
        return box_label_known, box_label_unknown

    def sample_for_pred_stop_sign(self, box_label: torch.FloatTensor):
        # every box is known; nothing is left to predict
        box_label_unknown = torch.Tensor(0, 5)
        box_label_known = box_label
        return box_label_known, box_label_unknown

    def __call__(self, target, img=None):
        box_label = target["box_label"]  # K, 5
        dice_number = random.random()

        if dice_number < self.prob_first_item:
            box_label_known, box_label_unknown = self.sample_for_pred_first_item(
                box_label
            )
        elif dice_number < self.prob_first_item + self.prob_random_item:
            box_label_known, box_label_unknown = self.sample_for_pred_random_item(
                box_label
            )
        elif (
            dice_number
            < self.prob_first_item + self.prob_random_item + self.prob_last_item
        ):
            box_label_known, box_label_unknown = self.sample_for_pred_last_item(
                box_label
            )
        else:
            box_label_known, box_label_unknown = self.sample_for_pred_stop_sign(
                box_label
            )

        target["label_onehot_known"] = label2onehot(
            box_label_known[:, -1], self.num_classes
        )
        target["label_onehot_unknown"] = label2onehot(
            box_label_unknown[:, -1], self.num_classes
        )
        target["box_label_known"] = box_label_known
        target["box_label_unknown"] = box_label_unknown

        return target, img
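# A hypothetical sketch of the sampling branches with the default probabilities
# (prob_first_item=0.0, prob_random_item=0.0, prob_last_item=0.8,
# prob_stop_sign=0.2): roughly 80% of calls route to sample_for_pred_last_item
# and the remaining 20% to sample_for_pred_stop_sign.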
class RandomDrop:
    def __init__(self, p=0.2) -> None:
        self.p = p

    def __call__(self, target, img=None):
        known_box = target["box_label_known"]
        num_known_box = known_box.size(0)
        idxs = torch.rand(num_known_box)
        # indices = torch.randperm(num_known_box)[:int((1 - self.p) * num_known_box + 0.5 + random.random())]
        target["box_label_known"] = known_box[idxs > self.p]
        return target, img
class BboxPertuber:
    def __init__(self, max_ratio=0.02, generate_samples=1000) -> None:
        self.max_ratio = max_ratio
        self.generate_samples = generate_samples
        self.samples = self.generate_pertube_samples()
        self.idx = 0

    def generate_pertube_samples(self):
        # pre-generate uniform noise in [-max_ratio, max_ratio]
        samples = (torch.rand(self.generate_samples, 5) - 0.5) * 2 * self.max_ratio
        return samples

    def __call__(self, target, img):
        known_box = target["box_label_known"]  # Tensor(K, 5), K known bboxes
        K = known_box.shape[0]
        known_box_pertube = torch.zeros(K, 6)  # 4: bbox, 1: prob, 1: label
        if K == 0:
            pass
        else:
            if self.idx + K > self.generate_samples:
                self.idx = 0
            delta = self.samples[self.idx : self.idx + K, :]
            self.idx += K  # advance the cursor so successive calls draw fresh noise
            known_box_pertube[:, :4] = known_box[:, :4] + delta[:, :4]
            iou = (
                torch.diag(
                    box_iou(
                        box_cxcywh_to_xyxy(known_box[:, :4]),
                        box_cxcywh_to_xyxy(known_box_pertube[:, :4]),
                    )[0]
                )
            ) * (1 + delta[:, -1])
            known_box_pertube[:, 4].copy_(iou)
            known_box_pertube[:, -1].copy_(known_box[:, -1])

        target["box_label_known_pertube"] = known_box_pertube
        return target, img
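# A hypothetical sketch of BboxPertuber with max_ratio=0.02: each known box in
# cxcywh form is shifted by uniform noise in [-0.02, 0.02] per coordinate, and
# column 4 stores the IoU between the original and perturbed box (scaled by
# 1 + noise) as a soft quality score for the perturbed box.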
class RandomCutout:
    def __init__(self, factor=0.5) -> None:
        self.factor = factor

    def __call__(self, target, img=None):
        unknown_box = target["box_label_unknown"]  # Ku, 5
        known_box = target["box_label_known_pertube"]  # Kk, 6
        Ku = unknown_box.size(0)

        known_box_add = torch.zeros(Ku, 6)  # Ku, 6
        known_box_add[:, :5] = unknown_box
        known_box_add[:, 5].uniform_(0.5, 1)

        known_box_add[:, :2] += known_box_add[:, 2:4] * (torch.rand(Ku, 2) - 0.5) / 2
        known_box_add[:, 2:4] /= 2

        target["box_label_known_pertube"] = torch.cat((known_box, known_box_add))
        return target, img
class RandomSelectBoxes:
    def __init__(self, num_class=80) -> None:
        warnings.warn("This is such a slow function and will be deprecated soon!!!")
        self.num_class = num_class

    def __call__(self, target, img=None):
        boxes = target["boxes"]
        labels = target["label_compat"]

        # transform to a list of per-class tensors
        boxs_list = [[] for i in range(self.num_class)]
        for idx, item in enumerate(boxes):
            label = labels[idx].item()
            boxs_list[label].append(item)
        boxs_list_tensor = [
            torch.stack(i) if len(i) > 0 else torch.Tensor(0, 4) for i in boxs_list
        ]

        # random selection
        box_known = []
        box_unknown = []
        for idx, item in enumerate(boxs_list_tensor):
            ncnt = item.shape[0]
            nselect = int(
                random.random() * ncnt
            )  # closed on both sides, much faster than random.randint
            item = item[torch.randperm(ncnt)]
            box_known.append(item[:nselect])
            box_unknown.append(item[nselect:])

        target["known_box"] = box_known
        target["unknown_box"] = box_unknown
        return target, img
def label2onehot(label, num_classes):
    """
    label: Tensor(K)
    """
    res = torch.zeros(num_classes)
    for i in label:
        itm = int(i.item())
        res[itm] = 1.0
    return res
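# A hypothetical example of label2onehot:
#   >>> label2onehot(torch.tensor([1.0, 3.0]), num_classes=5)
#   tensor([0., 1., 0., 1., 0.])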
class MaskCrop:
    def __init__(self) -> None:
        pass

    def __call__(self, target, img):
        known_box = target["known_box"]
        h, w = img.shape[1:]  # h, w
        # imgsize = target['orig_size'] # h, w

        scale = torch.Tensor([w, h, w, h])

        # zero out the pixels inside every known box
        for boxes in known_box:
            if boxes.shape[0] == 0:
                continue
            box_xyxy = box_cxcywh_to_xyxy(boxes) * scale
            for box in box_xyxy:
                x1, y1, x2, y2 = [int(i) for i in box.tolist()]
                img[:, y1:y2, x1:x2] = 0

        return target, img
dataset_hook_register = {
    "label2compat": label2compat,
    "label_compat2onehot": label_compat2onehot,
    "box_label_catter": box_label_catter,
    "RandomSelectBoxlabels": RandomSelectBoxlabels,
    "RandomSelectBoxes": RandomSelectBoxes,
    "MaskCrop": MaskCrop,
    "BboxPertuber": BboxPertuber,
}
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(
        self, img_folder, ann_file, transforms, return_masks, aux_target_hacks=None
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.aux_target_hacks = aux_target_hacks

    def change_hack_attr(self, hackclassname, attrkv_dict):
        target_class = dataset_hook_register[hackclassname]
        for item in self.aux_target_hacks:
            if isinstance(item, target_class):
                for k, v in attrkv_dict.items():
                    setattr(item, k, v)

    def get_hack(self, hackclassname):
        target_class = dataset_hook_register[hackclassname]
        for item in self.aux_target_hacks:
            if isinstance(item, target_class):
                return item

    def _load_image(self, id: int) -> Image.Image:
        path = self.coco.loadImgs(id)[0]["file_name"]
        abs_path = os.path.join(self.root, path)
        return Image.open(abs_path).convert("RGB")
    def __getitem__(self, idx):
        """
        Output:
            - target: dict of multiple items
                - boxes: Tensor[num_box, 4].
                    Init type: x0,y0,x1,y1, unnormalized data.
                    Final type: cx,cy,w,h, normalized data.
        """
        try:
            img, target = super(CocoDetection, self).__getitem__(idx)
        except Exception:
            print("Error idx: {}".format(idx))
            idx += 1
            img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {"image_id": image_id, "annotations": target}
        exemp_count = 0
        for instance in target["annotations"]:
            if instance["area"] != 4:
                exemp_count += 1
        # Only provide at most 3 visual exemplars during inference.
        assert exemp_count == 3
        img, target = self.prepare(img, target)
        target["exemplars"] = target["boxes"][-3:]
        # Remove inaccurate exemplars.
        if image_id == 6003:
            target["exemplars"] = torch.tensor([])
        target["boxes"] = target["boxes"][:-3]
        target["labels"] = target["labels"][:-3]
        target["labels_uncropped"] = torch.clone(target["labels"])

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # convert to the needed format
        if self.aux_target_hacks is not None:
            for hack_runner in self.aux_target_hacks:
                target, img = hack_runner(target, img=img)

        return img, target
def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
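# A hypothetical example: for one annotation with a single polygon, the result
# is a (1, height, width) uint8 tensor with 1s inside the polygon; with no
# annotations it degrades gracefully to shape (0, height, width).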
class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # convert [x, y, w, h] to [x0, y0, x1, y1] and clamp to the image
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # drop degenerate boxes
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor(
            [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]
        )
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target
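# A hypothetical example of the box conversion above: a COCO [x, y, w, h] box
# [10, 20, 30, 40] becomes [10, 20, 40, 60] in [x0, y0, x1, y1] form before any
# clamping or normalization.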
def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    # config the params for data aug
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    max_size = 1333
    scales2_resize = [400, 500, 600]
    scales2_crop = [384, 600]

    # update args from config files
    scales = getattr(args, "data_aug_scales", scales)
    max_size = getattr(args, "data_aug_max_size", max_size)
    scales2_resize = getattr(args, "data_aug_scales2_resize", scales2_resize)
    scales2_crop = getattr(args, "data_aug_scales2_crop", scales2_crop)

    # resize them
    data_aug_scale_overlap = getattr(args, "data_aug_scale_overlap", None)
    if data_aug_scale_overlap is not None and data_aug_scale_overlap > 0:
        data_aug_scale_overlap = float(data_aug_scale_overlap)
        scales = [int(i * data_aug_scale_overlap) for i in scales]
        max_size = int(max_size * data_aug_scale_overlap)
        scales2_resize = [int(i * data_aug_scale_overlap) for i in scales2_resize]
        scales2_crop = [int(i * data_aug_scale_overlap) for i in scales2_crop]

    datadict_for_print = {
        "scales": scales,
        "max_size": max_size,
        "scales2_resize": scales2_resize,
        "scales2_crop": scales2_crop,
    }
    # print("data_aug_params:", json.dumps(datadict_for_print, indent=2))

    if image_set == "train":
        if fix_size:
            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomResize([(max_size, max(scales))]),
                    # T.RandomResize([(512, 512)]),
                    normalize,
                ]
            )

        if strong_aug:
            import datasets.sltransform as SLT

            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomSelect(
                        T.RandomResize(scales, max_size=max_size),
                        T.Compose(
                            [
                                T.RandomResize(scales2_resize),
                                T.RandomSizeCrop(*scales2_crop),
                                T.RandomResize(scales, max_size=max_size),
                            ]
                        ),
                    ),
                    SLT.RandomSelectMulti(
                        [
                            SLT.RandomCrop(),
                            SLT.LightingNoise(),
                            SLT.AdjustBrightness(2),
                            SLT.AdjustContrast(2),
                        ]
                    ),
                    normalize,
                ]
            )

        return T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(scales, max_size=max_size),
                    T.Compose(
                        [
                            T.RandomResize(scales2_resize),
                            T.RandomSizeCrop(*scales2_crop),
                            T.RandomResize(scales, max_size=max_size),
                        ]
                    ),
                ),
                normalize,
            ]
        )

    if image_set in ["val", "eval_debug", "train_reg", "test"]:
        if os.environ.get("GFLOPS_DEBUG_SHILONG", False) == "INFO":
            print("Under debug mode for flops calculation only!!!!!!!!!!!!!!!!")
            return T.Compose(
                [
                    T.ResizeDebug((1280, 800)),
                    normalize,
                ]
            )

        print("max(scales): " + str(max(scales)))
        return T.Compose(
            [
                T.RandomResize([max(scales)], max_size=max_size),
                normalize,
            ]
        )

    raise ValueError(f"unknown {image_set}")
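# A minimal usage sketch (assuming an argparse-style `args`; every attribute
# read above falls back to the defaults when missing):
#   transforms = make_coco_transforms("train", fix_size=False, strong_aug=False, args=args)
#   img_t, target_t = transforms(img, target)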
def get_aux_target_hacks_list(image_set, args):
    if args.modelname in ["q2bs_mask", "q2bs"]:
        aux_target_hacks_list = [
            label2compat(),
            label_compat2onehot(),
            RandomSelectBoxes(num_class=args.num_classes),
        ]
        if args.masked_data and image_set == "train":
            aux_target_hacks_list.append(MaskCrop())
    elif args.modelname in [
        "q2bm_v2",
        "q2bs_ce",
        "q2op",
        "q2ofocal",
        "q2opclip",
        "q2ocqonly",
    ]:
        aux_target_hacks_list = [
            label2compat(),
            label_compat2onehot(),
            box_label_catter(),
            RandomSelectBoxlabels(
                num_classes=args.num_classes,
                prob_first_item=args.prob_first_item,
                prob_random_item=args.prob_random_item,
                prob_last_item=args.prob_last_item,
                prob_stop_sign=args.prob_stop_sign,
            ),
            BboxPertuber(max_ratio=0.02, generate_samples=1000),
        ]
    elif args.modelname in ["q2omask", "q2osa"]:
        if args.coco_aug:
            aux_target_hacks_list = [
                label2compat(),
                label_compat2onehot(),
                box_label_catter(),
                RandomSelectBoxlabels(
                    num_classes=args.num_classes,
                    prob_first_item=args.prob_first_item,
                    prob_random_item=args.prob_random_item,
                    prob_last_item=args.prob_last_item,
                    prob_stop_sign=args.prob_stop_sign,
                ),
                RandomDrop(p=0.2),
                BboxPertuber(max_ratio=0.02, generate_samples=1000),
                RandomCutout(factor=0.5),
            ]
        else:
            aux_target_hacks_list = [
                label2compat(),
                label_compat2onehot(),
                box_label_catter(),
                RandomSelectBoxlabels(
                    num_classes=args.num_classes,
                    prob_first_item=args.prob_first_item,
                    prob_random_item=args.prob_random_item,
                    prob_last_item=args.prob_last_item,
                    prob_stop_sign=args.prob_stop_sign,
                ),
                BboxPertuber(max_ratio=0.02, generate_samples=1000),
            ]
    else:
        aux_target_hacks_list = None

    return aux_target_hacks_list
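# A hypothetical sketch of how the hooks compose inside CocoDetection.__getitem__
# (img is not None there, so every hook returns a (target, img) pair; the list
# is None unless args.modelname matches one of the branches above):
#   hooks = get_aux_target_hacks_list("train", args)
#   for hook in hooks:
#       target, img = hook(target, img=img)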
def build(image_set, args, datasetinfo):
    img_folder = datasetinfo["root"]
    ann_file = datasetinfo["anno"]

    # copy to local path
    if os.environ.get("DATA_COPY_SHILONG") == "INFO":
        preparing_dataset(
            dict(img_folder=img_folder, ann_file=ann_file), image_set, args
        )

    strong_aug = getattr(args, "strong_aug", False)
    print(img_folder, ann_file)
    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(
            image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args
        ),
        return_masks=args.masks,
        aux_target_hacks=None,
    )
    return dataset
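# A minimal usage sketch (`args` and `datasetinfo` are assumed to come from the
# training config; "root"/"anno" point at the image folder and annotation json):
#   dataset = build("train", args, {"root": "/path/to/images", "anno": "/path/to/anno.json"})
#   img, target = dataset[0]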
if __name__ == "__main__":
    # Objects365 Val example
    dataset_o365 = CocoDetection(
        "/path/Objects365/train/",
        "/path/Objects365/slannos/anno_preprocess_train_v2.json",
        transforms=None,
        return_masks=False,
    )
    print("len(dataset_o365):", len(dataset_o365))