CDGPT2-Deployment / generator.py
Ziad Meligy
Pushing deployment to space
eb8805a
raw
history blame
3.71 kB
import numpy as np
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage.transform import resize
import torch
from torchvision import transforms
class AugmentedImageSequence(Dataset):
    """
    Batched image/target dataset with optional imgaug-style augmentation.

    Unlike a typical map-style ``Dataset``, each index addresses a whole batch
    of ``batch_size`` samples (images are stacked inside ``__getitem__``), so
    this is presumably meant to be wrapped with ``DataLoader(..., batch_size=None)``
    -- TODO confirm against callers.

    Note: ``on_epoch_end()`` reassigns ``x_path`` and ``y`` in place, so it
    should not run concurrently with ``__getitem__``.
    """

    def __init__(self, dataset_csv_file, class_names, source_image_dir, tokenizer_wrapper, batch_size=16,
                 target_size=(224, 224), augmenter=None, verbose=0, steps=None,
                 shuffle_on_epoch_end=True, random_state=1):
        """
        :param dataset_csv_file: str, path of dataset csv file. Must contain an
            "Image Index" column plus every column listed in ``class_names``.
        :param class_names: list of str, CSV columns whose values are passed to
            ``tokenizer_wrapper.GPT2_encode`` to build the targets.
        :param source_image_dir: str, directory the "Image Index" paths are joined to.
        :param tokenizer_wrapper: object exposing ``GPT2_encode(values)``.
        :param batch_size: int, number of samples returned per index.
        :param target_size: tuple(int, int), (rows, cols) every image is resized to.
        :param augmenter: imgaug-style object with ``augment_image`` (HxWxC in,
            HxWxC out), or None. Do not specify resize in the augmenter; resizing
            to ``target_size`` is done automatically when images are loaded.
        :param verbose: int, stored on the instance (not read by any method here).
        :param steps: int or None, number of batches; defaults to
            ceil(number of rows / batch_size).
        :param shuffle_on_epoch_end: bool, reshuffle the rows in ``on_epoch_end()``.
        :param random_state: int, seed for the row shuffle in ``prepare_dataset()``.
        """
        self.dataset_df = pd.read_csv(dataset_csv_file)
        self.source_image_dir = source_image_dir
        self.batch_size = batch_size
        self.target_size = target_size
        self.augmenter = augmenter
        self.tokenizer_wrapper = tokenizer_wrapper
        self.verbose = verbose
        self.shuffle = shuffle_on_epoch_end
        self.random_state = random_state
        self.class_names = class_names
        self.prepare_dataset()
        if steps is None:
            self.steps = int(np.ceil(len(self.x_path) / float(self.batch_size)))
        else:
            self.steps = int(steps)
        # Images are already float tensors (see load_image), so only
        # ImageNet-statistics normalization is applied here.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of batches."""
        return self.steps

    def __getitem__(self, idx):
        """
        Return batch ``idx`` as ``(images, targets, image_paths)``.

        :param idx: int in ``range(len(self))``.
        :return: tuple of
            - float tensor (N, 3, rows, cols) of normalized (optionally augmented) images,
            - ``torch.tensor`` built from the matching slice of ``self.y``,
            - list of the N image paths.
        :raises IndexError: if ``idx`` is outside ``range(len(self))``; this also
            lets plain ``for batch in dataset`` iteration terminate.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"batch index {idx} out of range for {len(self)} batches")
        start = idx * self.batch_size
        end = start + self.batch_size
        batch_x_path = self.x_path[start:end]
        batch_x = torch.stack([self.load_image(x_path) for x_path in batch_x_path])
        batch_x = self.transform_batch_images(batch_x)
        batch_y = torch.tensor(self.y[start:end])
        return batch_x, batch_y, batch_x_path.tolist()

    def load_image(self, image_file):
        """
        Load one image as a float32 tensor of shape (3, rows, cols).

        Pixel values are scaled by 1/255 and the image is resized to
        ``self.target_size``.

        :param image_file: path relative to ``self.source_image_dir``.
        """
        image_path = os.path.join(self.source_image_dir, image_file)
        # The context manager closes the underlying file handle; convert()
        # returns a new, fully loaded image, so the data stays valid afterwards.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB")) / 255.
        image_array = resize(image_array, self.target_size)
        return torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)  # HxWxC -> CxHxW

    def transform_batch_images(self, batch_x):
        """
        Optionally augment each image, then normalize the whole batch.

        :param batch_x: float tensor of shape (N, 3, H, W).
        :return: tensor of shape (N, 3, H, W), float32 when augmentation ran.
        """
        if self.augmenter is not None:
            augmented = []
            for img in batch_x:
                # imgaug operates on HxWxC numpy arrays.
                aug = self.augmenter.augment_image(img.permute(1, 2, 0).numpy())
                # Make the result contiguous (flips may yield negative strides)
                # and restore CxHxW so Normalize sees channels at dim -3.
                aug_tensor = torch.tensor(np.ascontiguousarray(aug), dtype=torch.float32)
                augmented.append(aug_tensor.permute(2, 0, 1))
            batch_x = torch.stack(augmented)
        return self.transform(batch_x)

    def get_y_true(self):
        """
        Return the targets covered by ``len(self)`` batches as a float32 tensor.

        Only meaningful when shuffle_on_epoch_end is False, because shuffling
        would reorder ``self.y`` between calls.

        :raises ValueError: if shuffle_on_epoch_end is True.
        """
        if self.shuffle:
            raise ValueError("get_y_true() can only be used when shuffle_on_epoch_end is False.")
        return torch.tensor(self.y[:self.steps * self.batch_size], dtype=torch.float32)

    def prepare_dataset(self):
        """Shuffle the CSV rows with ``self.random_state`` and build ``x_path`` / ``y``."""
        df = self.dataset_df.sample(frac=1., random_state=self.random_state)
        # TODO: original note said "tokenize the targets"; targets are currently
        # encoded via tokenizer_wrapper.GPT2_encode -- confirm whether more is needed.
        self.x_path = df["Image Index"].values
        self.y = self.tokenizer_wrapper.GPT2_encode(df[self.class_names].values)

    def on_epoch_end(self):
        """Reshuffle rows with a new seed if shuffle_on_epoch_end is enabled."""
        if self.shuffle:
            self.random_state += 1
            self.prepare_dataset()