import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class AugmentedImageSequence(Dataset):
    """
    Batch-level image dataset with optional imgaug augmentation for PyTorch.

    Each index yields one whole batch (images, encoded targets, image file
    names) rather than a single sample; ``len(dataset)`` is the number of
    batches (``steps``).
    """

    def __init__(self, dataset_csv_file, class_names, source_image_dir, tokenizer_wrapper, batch_size=16,
                 target_size=(224, 224), augmenter=None, verbose=0, steps=None,
                 shuffle_on_epoch_end=True, random_state=1):
        """
        :param dataset_csv_file: str, path of dataset csv file. Must contain an
            "Image Index" column plus the columns listed in ``class_names``.
        :param class_names: list of str, csv columns passed to the tokenizer
        :param source_image_dir: str, directory the "Image Index" file names
            are resolved against
        :param tokenizer_wrapper: object exposing ``GPT2_encode(values)``; its
            output is stored as the targets ``self.y``
        :param batch_size: int, number of images per returned batch
        :param target_size: tuple(int, int), size every image is resized to
        :param augmenter: imgaug object. Do not specify resize in augmenter.
            It will be done automatically according to ``target_size``.
        :param verbose: int, stored on the instance (not used by this class)
        :param steps: int or None, number of batches reported by ``__len__``;
            defaults to ``ceil(n_images / batch_size)``
        :param shuffle_on_epoch_end: bool, whether ``on_epoch_end`` reshuffles
        :param random_state: int, seed for the row shuffle
        """
        self.dataset_df = pd.read_csv(dataset_csv_file)
        self.source_image_dir = source_image_dir
        self.batch_size = batch_size
        self.target_size = target_size
        self.augmenter = augmenter
        self.tokenizer_wrapper = tokenizer_wrapper
        self.verbose = verbose
        self.shuffle = shuffle_on_epoch_end
        self.random_state = random_state
        self.class_names = class_names
        self.prepare_dataset()
        if steps is None:
            self.steps = int(np.ceil(len(self.x_path) / float(self.batch_size)))
        else:
            self.steps = int(steps)
        # No ToTensor here: load_image() already returns float CxHxW tensors.
        # Mean/std are the commonly used ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of batches (``steps``)."""
        return self.steps

    def __getitem__(self, idx):
        """
        Return batch number ``idx``.

        :param idx: int, zero-based batch index
        :return: tuple ``(batch_x, batch_y, paths)`` where ``batch_x`` is the
            normalized image tensor (B x C x H x W), ``batch_y`` is a tensor
            built from the matching slice of the encoded targets, and
            ``paths`` is the list of image file names in the batch
        :raises IndexError: if ``idx`` selects no images
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x_path = self.x_path[start:stop]
        if len(batch_x_path) == 0:
            # torch.stack([]) would fail with an opaque RuntimeError; raising
            # IndexError also lets plain iteration over the dataset terminate.
            raise IndexError("batch index {} is out of range".format(idx))
        batch_x = torch.stack([self.load_image(x_path) for x_path in batch_x_path])
        batch_x = self.transform_batch_images(batch_x)
        # NOTE(review): assumes GPT2_encode returns a rectangular numeric
        # array/list that torch.tensor can convert - confirm against tokenizer.
        batch_y = torch.tensor(self.y[start:stop])
        return batch_x, batch_y, batch_x_path.tolist()

    def load_image(self, image_file):
        """
        Load one image as a float32 CxHxW tensor scaled to [0, 1].

        :param image_file: str, file name relative to ``source_image_dir``
        :return: torch.Tensor of shape (3, target_h, target_w)
        """
        image_path = os.path.join(self.source_image_dir, image_file)
        # Context manager closes the underlying file handle; convert() forces
        # the pixel data to be read before the file is closed.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        image_array = np.asarray(rgb_image) / 255.
        image_array = resize(image_array, self.target_size)
        image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)  # Convert to CxHxW
        return image_tensor

    def transform_batch_images(self, batch_x):
        """
        Apply the optional augmenter to every image, then normalize the batch.

        :param batch_x: torch.Tensor of shape (B, C, H, W)
        :return: torch.Tensor of shape (B, C, H, W)
        """
        if self.augmenter is not None:
            augmented = []
            for img in batch_x:
                # The augmenter is fed channel-last (HxWxC) arrays.
                hwc = self.augmenter.augment_image(img.permute(1, 2, 0).numpy())
                # Augmenters such as flips may return negative-stride views,
                # which torch cannot wrap; make the array contiguous first.
                hwc = np.ascontiguousarray(hwc)
                # Back to CxHxW so Normalize sees channels where it expects them.
                chw = torch.as_tensor(hwc, dtype=batch_x.dtype).permute(2, 0, 1)
                augmented.append(chw)
            batch_x = torch.stack(augmented)
        batch_x = self.transform(batch_x)
        return batch_x

    def get_y_true(self):
        """
        Use this function to get y_true for DataLoader.
        Ensure shuffle_on_epoch_end is False before using.

        :return: float32 tensor of the first ``steps * batch_size`` targets
        :raises ValueError: if shuffling on epoch end is enabled
        """
        if self.shuffle:
            raise ValueError("get_y_true() can only be used when shuffle_on_epoch_end is False.")
        return torch.tensor(self.y[:self.steps * self.batch_size], dtype=torch.float32)

    def prepare_dataset(self):
        """
        Shuffle the csv rows and (re)build ``self.x_path`` and ``self.y``.

        The rows are always shuffled with ``random_state`` as seed, regardless
        of ``shuffle_on_epoch_end``; that flag only controls reshuffling in
        ``on_epoch_end``.
        """
        df = self.dataset_df.sample(frac=1., random_state=self.random_state)
        ## @TODO: tokenize the targets
        self.x_path, self.y = df["Image Index"].values, self.tokenizer_wrapper.GPT2_encode(
            df[self.class_names].values)

    def on_epoch_end(self):
        """
        Reshuffle the dataset with a new seed when shuffling is enabled.

        Nothing in this class calls this automatically; the training loop is
        expected to invoke it between epochs.
        """
        if self.shuffle:
            self.random_state += 1
            self.prepare_dataset()