Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import os | |
| import pandas as pd | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| from skimage.transform import resize | |
| import torch | |
| from torchvision import transforms | |
class AugmentedImageSequence(Dataset):
    """
    Batch-level image dataset with optional imgaug augmentation for PyTorch.

    Unlike a typical ``Dataset``, every index addresses a whole *batch*:
    ``dataset[i]`` returns ``(batch_x, batch_y, batch_paths)`` for batch ``i``.
    When wrapping this in a ``DataLoader``, automatic batching should therefore
    be disabled (e.g. ``batch_size=None``).
    """

    def __init__(self, dataset_csv_file, class_names, source_image_dir, tokenizer_wrapper, batch_size=16,
                 target_size=(224, 224), augmenter=None, verbose=0, steps=None,
                 shuffle_on_epoch_end=True, random_state=1):
        """
        :param dataset_csv_file: str, path of dataset csv file. The csv must
            contain an "Image Index" column plus every column in class_names.
        :param class_names: list of str, csv columns passed to the tokenizer as targets
        :param source_image_dir: str, directory that the "Image Index" entries are relative to
        :param tokenizer_wrapper: object exposing ``GPT2_encode(values)``; its result is
            stored as the targets (assumed to be sliceable per sample -- TODO confirm)
        :param batch_size: int, number of images per batch
        :param target_size: tuple(int, int), size every image is resized to
        :param augmenter: imgaug object. Do not specify resize in augmenter.
            It will be done automatically according to target_size.
        :param verbose: int, stored on the instance (not used inside this class)
        :param steps: int or None, number of batches per epoch; when None it is
            derived as ceil(num_samples / batch_size)
        :param shuffle_on_epoch_end: bool, re-shuffle with a new seed in on_epoch_end()
        :param random_state: int, seed used for the (re)shuffling of the csv rows
        """
        super().__init__()
        self.dataset_df = pd.read_csv(dataset_csv_file)
        self.source_image_dir = source_image_dir
        self.batch_size = batch_size
        self.target_size = target_size
        self.augmenter = augmenter
        self.tokenizer_wrapper = tokenizer_wrapper
        self.verbose = verbose
        self.shuffle = shuffle_on_epoch_end
        self.random_state = random_state
        self.class_names = class_names
        # Populates self.x_path / self.y, which the step computation below needs.
        self.prepare_dataset()
        if steps is None:
            self.steps = int(np.ceil(len(self.x_path) / float(self.batch_size)))
        else:
            self.steps = int(steps)
        # Images are already float tensors in [0, 1] (see load_image), so only
        # per-channel normalization is applied here (commonly used ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps

    def __getitem__(self, idx):
        """
        Load, augment and normalize batch number ``idx``.

        :param idx: int, batch index in ``[0, len(self))``
        :return: tuple ``(batch_x, batch_y, batch_x_path)`` where batch_x is an
            N x C x H x W float tensor, batch_y the matching targets as a tensor
            and batch_x_path the list of image file names in the batch
        :raises IndexError: if idx is out of range or addresses an empty batch
        """
        # Raising IndexError (instead of failing later inside torch.stack on an
        # empty list) lets plain ``for batch in dataset`` iteration terminate.
        if not 0 <= idx < self.steps:
            raise IndexError("batch index {} out of range for {} steps".format(idx, self.steps))
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x_path = self.x_path[start:stop]
        if len(batch_x_path) == 0:
            # Can happen when an explicit ``steps`` exceeds the available data.
            raise IndexError("batch index {} is beyond the end of the dataset".format(idx))
        batch_x = torch.stack([self.load_image(x_path) for x_path in batch_x_path])
        batch_x = self.transform_batch_images(batch_x)
        batch_y = torch.tensor(self.y[start:stop])
        return batch_x, batch_y, batch_x_path.tolist()

    def load_image(self, image_file):
        """
        Read one image and return it as a C x H x W float32 tensor in [0, 1].

        :param image_file: str, file name relative to source_image_dir
        """
        image_path = os.path.join(self.source_image_dir, image_file)
        # The context manager closes the underlying file handle; convert()
        # returns a fully loaded copy, so the pixel data outlives the block.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB")) / 255.
        image_array = resize(image_array, self.target_size)
        # H x W x C -> C x H x W
        return torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)

    def transform_batch_images(self, batch_x):
        """
        Apply the optional augmenter per image, then normalize the batch.

        :param batch_x: N x C x H x W tensor
        :return: N x C x H x W tensor
        """
        if self.augmenter is not None:
            augmented = []
            for img in batch_x:
                # The augmenter works on H x W x C numpy arrays.
                aug = self.augmenter.augment_image(img.permute(1, 2, 0).numpy())
                # Augmenters may hand back views with negative strides (e.g.
                # flips), which torch cannot wrap -- make the array contiguous.
                aug = np.ascontiguousarray(aug)
                # Back to C x H x W so that Normalize sees channels where it expects them.
                augmented.append(torch.tensor(aug).permute(2, 0, 1))
            batch_x = torch.stack(augmented)
        batch_x = self.transform(batch_x)
        return batch_x

    def get_y_true(self):
        """
        Use this function to get y_true for DataLoader.
        Ensure shuffle_on_epoch_end is False before using.

        :return: float32 tensor with the targets of the first steps * batch_size samples
        :raises ValueError: if shuffle_on_epoch_end is True
        """
        if self.shuffle:
            raise ValueError("get_y_true() can only be used when shuffle_on_epoch_end is False.")
        return torch.tensor(self.y[:self.steps * self.batch_size], dtype=torch.float32)

    def prepare_dataset(self):
        """
        Shuffle the csv rows with the current random_state and (re)build
        self.x_path (image file names) and self.y (encoded targets).
        """
        df = self.dataset_df.sample(frac=1., random_state=self.random_state)
        ## @TODO: tokenize the targets
        self.x_path = df["Image Index"].values
        self.y = self.tokenizer_wrapper.GPT2_encode(df[self.class_names].values)

    def on_epoch_end(self):
        """
        Re-shuffle the dataset with a new seed when shuffling is enabled.

        This must be invoked explicitly by the training loop.
        """
        if self.shuffle:
            self.random_state += 1
            self.prepare_dataset()