from torchvision.datasets.vision import VisionDataset

import json
import os
import random
import sys
from typing import Callable, Optional

import torch
from PIL import Image

# Make the repository root importable when this file is run as a script.
sys.path.append(os.path.dirname(sys.path[0]))
import datasets.transforms as T
class ODVGDataset(VisionDataset):
    """
    Args:
        root (string): Root directory where images are downloaded to.
        anno (string): Path to the JSON annotation file.
        label_map_anno (string): Path to the JSON label-map file. Only used for
            object detection (OD mode).
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version, e.g. ``transforms.PILToTensor``.
        target_transform (callable, optional): A function/transform that takes in
            the target and transforms it.
        transforms (callable, optional): A function/transform that takes an input
            sample and its target as entry and returns a transformed version.
    """
    def __init__(
        self,
        root: str,
        anno: str,
        label_map_anno: Optional[str] = None,
        max_labels: int = 80,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        self.root = root
        # A label map means plain object detection (OD); otherwise the
        # annotations are visual-grounding (VG) regions with free-form phrases.
        self.dataset_mode = "OD" if label_map_anno else "VG"
        self.max_labels = max_labels
        if self.dataset_mode == "OD":
            self.load_label_map(label_map_anno)
        self._load_metas(anno)
        self.get_dataset_info()
    def load_label_map(self, label_map_anno):
        with open(label_map_anno, "r") as file:
            self.label_map = json.load(file)
        self.label_index = set(self.label_map.keys())

    def _load_metas(self, anno):
        # The annotation file is JSON Lines: one JSON object per image.
        with open(anno, "r") as f:
            self.metas = [json.loads(line) for line in f]

    def get_dataset_info(self):
        print(f" == total images: {len(self)}")
        if self.dataset_mode == "OD":
            print(f" == total labels: {len(self.label_map)}")
    def __getitem__(self, index: int):
        meta = self.metas[index]
        rel_path = meta["filename"]
        abs_path = os.path.join(self.root, rel_path)
        if not os.path.exists(abs_path):
            raise FileNotFoundError(f"{abs_path} not found.")
        image = Image.open(abs_path).convert("RGB")
        exemplars = torch.tensor(meta["exemplars"], dtype=torch.int64)
        w, h = image.size
        if self.dataset_mode == "OD":
            anno = meta["detection"]
            instances = anno["instances"]
            boxes = [obj["bbox"] for obj in instances]
            # Build VG-style labels: the positive classes present in the image,
            # padded with negative classes sampled from the rest of the label map.
            ori_classes = [str(obj["label"]) for obj in instances]
            pos_labels = set(ori_classes)
            neg_labels = self.label_index.difference(pos_labels)
            vg_labels = list(pos_labels)
            num_to_add = min(len(neg_labels), self.max_labels - len(pos_labels))
            if num_to_add > 0:
                # random.sample needs a sequence; passing a set directly raises
                # a TypeError on Python >= 3.11.
                vg_labels.extend(random.sample(list(neg_labels), num_to_add))
            random.shuffle(vg_labels)
            caption_list = [self.label_map[lb] for lb in vg_labels]
            caption_dict = {item: index for index, item in enumerate(caption_list)}
            caption = " . ".join(caption_list) + " ."
            classes = [
                caption_dict[self.label_map[str(obj["label"])]] for obj in instances
            ]
            boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
            classes = torch.tensor(classes, dtype=torch.int64)
        elif self.dataset_mode == "VG":
            anno = meta["grounding"]
            instances = anno["regions"]
            boxes = [obj["bbox"] for obj in instances]
            caption_list = [obj["phrase"] for obj in instances]
            # Shuffle boxes and phrases together so they stay aligned.
            paired = list(zip(boxes, caption_list))
            random.shuffle(paired)
            boxes[:], caption_list[:] = zip(*paired)
            uni_caption_list = list(set(caption_list))
            label_map = {cap: idx for idx, cap in enumerate(uni_caption_list)}
            classes = [label_map[cap] for cap in caption_list]
            caption = " . ".join(uni_caption_list) + " ."
            boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
            classes = torch.tensor(classes, dtype=torch.int64)
            caption_list = uni_caption_list
        # Target keys: size, cap_list, caption, boxes, labels, exemplars,
        # labels_uncropped.
        target = {}
        target["size"] = torch.as_tensor([int(h), int(w)])
        target["cap_list"] = caption_list
        target["caption"] = caption
        target["boxes"] = boxes
        target["labels"] = classes
        target["exemplars"] = exemplars
        target["labels_uncropped"] = torch.clone(classes)
        if self.transforms is not None:
            image, target = self.transforms(image, target)
            # Check that the transforms did not change the identity of
            # target['labels'] relative to the uncropped copy saved above.
            if len(target["labels"]) > 0:
                assert target["labels"][0] == target["labels_uncropped"][0]
        return image, target

    def __len__(self) -> int:
        return len(self.metas)
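
# For reference, the fields read in __getitem__ imply JSONL records shaped
# roughly as follows. This is a sketch inferred from this file, not an
# authoritative schema; the field values are illustrative.
#
#   OD mode:
#     {"filename": "img.jpg",
#      "exemplars": [[x0, y0, x1, y1], ...],
#      "detection": {"instances": [{"bbox": [x0, y0, x1, y1], "label": 3}, ...]}}
#
#   VG mode:
#     {"filename": "img.jpg",
#      "exemplars": [[x0, y0, x1, y1], ...],
#      "grounding": {"regions": [{"bbox": [x0, y0, x1, y1], "phrase": "a dog"}, ...]}}
#
#   OD label map (label_map_anno): {"3": "cat", ...}, i.e. label ids (as
#   strings) mapped to category names.
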
def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    # Default data-augmentation parameters.
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    max_size = 1333
    scales2_resize = [400, 500, 600]
    scales2_crop = [384, 600]

    # Override the defaults from the config/args, when present.
    scales = getattr(args, "data_aug_scales", scales)
    max_size = getattr(args, "data_aug_max_size", max_size)
    scales2_resize = getattr(args, "data_aug_scales2_resize", scales2_resize)
    scales2_crop = getattr(args, "data_aug_scales2_crop", scales2_crop)

    # Optionally rescale all augmentation sizes by a common factor.
    data_aug_scale_overlap = getattr(args, "data_aug_scale_overlap", None)
    if data_aug_scale_overlap is not None and data_aug_scale_overlap > 0:
        data_aug_scale_overlap = float(data_aug_scale_overlap)
        scales = [int(i * data_aug_scale_overlap) for i in scales]
        max_size = int(max_size * data_aug_scale_overlap)
        scales2_resize = [int(i * data_aug_scale_overlap) for i in scales2_resize]
        scales2_crop = [int(i * data_aug_scale_overlap) for i in scales2_crop]
    if image_set == "train":
        if fix_size:
            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomResize([(max_size, max(scales))]),
                    normalize,
                ]
            )
        if strong_aug:
            import datasets.sltransform as SLT

            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomSelect(
                        T.RandomResize(scales, max_size=max_size),
                        T.Compose(
                            [
                                T.RandomResize(scales2_resize),
                                T.RandomSizeCrop(*scales2_crop),
                                T.RandomResize(scales, max_size=max_size),
                            ]
                        ),
                    ),
                    SLT.RandomSelectMulti(
                        [
                            SLT.RandomCrop(),
                            SLT.LightingNoise(),
                            SLT.AdjustBrightness(2),
                            SLT.AdjustContrast(2),
                        ]
                    ),
                    normalize,
                ]
            )
        return T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(scales, max_size=max_size),
                    T.Compose(
                        [
                            T.RandomResize(scales2_resize),
                            T.RandomSizeCrop(*scales2_crop),
                            T.RandomResize(scales, max_size=max_size),
                        ]
                    ),
                ),
                normalize,
            ]
        )

    if image_set in ["val", "eval_debug", "train_reg", "test"]:
        if os.environ.get("GFLOPS_DEBUG_SHILONG") == "INFO":
            print("Under debug mode for flops calculation only!")
            return T.Compose(
                [
                    T.ResizeDebug((1280, 800)),
                    normalize,
                ]
            )
        return T.Compose(
            [
                T.RandomResize([max(scales)], max_size=max_size),
                normalize,
            ]
        )

    raise ValueError(f"unknown image_set: {image_set}")
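
# A minimal usage sketch for make_coco_transforms. Passing args=None is assumed
# to be safe because every option is read via getattr(..., default). The helper
# name below is illustrative, not part of the original file.
def _demo_make_transforms():
    train_tfm = make_coco_transforms("train", args=None)  # multi-scale train aug
    val_tfm = make_coco_transforms("val", args=None)  # single-scale eval resize
    return train_tfm, val_tfm
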
def build_odvg(image_set, args, datasetinfo):
    img_folder = datasetinfo["root"]
    ann_file = datasetinfo["anno"]
    label_map = datasetinfo.get("label_map")
    strong_aug = getattr(args, "strong_aug", False)
    print(img_folder, ann_file, label_map)
    dataset = ODVGDataset(
        img_folder,
        ann_file,
        label_map,
        max_labels=args.max_labels,
        transforms=make_coco_transforms(
            image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args
        ),
    )
    return dataset
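
# A hedged example of wiring build_odvg: the datasetinfo keys and args
# attributes mirror exactly what the function reads above. SimpleNamespace
# stands in for the real argparse namespace, and the paths are the same
# placeholders used in the __main__ block below.
def _demo_build_odvg():
    from types import SimpleNamespace

    args = SimpleNamespace(max_labels=80, fix_size=False, strong_aug=False)
    datasetinfo = {
        "root": "path/V3Det/",
        "anno": "path/V3Det/annotations/v3det_2023_v1_all_odvg.jsonl",
        "label_map": "path/V3Det/annotations/v3det_label_map.json",
    }
    return build_odvg("train", args, datasetinfo)
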
if __name__ == "__main__":
    dataset_vg = ODVGDataset(
        "path/GRIT-20M/data/",
        "path/GRIT-20M/anno/grit_odvg_10k.jsonl",
    )
    print(len(dataset_vg))
    data = dataset_vg[random.randint(0, 100)]
    print(data)

    dataset_od = ODVGDataset(
        "path/V3Det/",
        "path/V3Det/annotations/v3det_2023_v1_all_odvg.jsonl",
        "path/V3Det/annotations/v3det_label_map.json",
    )
    print(len(dataset_od))
    data = dataset_od[random.randint(0, 100)]
    print(data)