import os

import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from transformers import Trainer, TrainingArguments

from linear_mapping import LinearMapping, LinearMappingConfig, LinearMappingProcessor, Transform

os.environ["WANDB_DISABLED"] = "true"  # disable Weights & Biases logging

DATA_DIR = os.path.join(os.getcwd(), "coco")
CAPTION_COLUMN = "caption"
IMAGE_COLUMN = "image_path"

def main():
    # COCO 2017 captions via the ydshieh/coco_dataset_script loader; data lands in DATA_DIR
    ds = load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=DATA_DIR)
    config = LinearMappingConfig()
    processor = LinearMappingProcessor(config)
    def collate_fn(batch):
        return {
            "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
            "input_ids": torch.tensor([x["input_ids"] for x in batch], dtype=torch.long),
            "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        }
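    # Expected collated shapes, assuming an image prefix of config.prefix_length
    # tokens (illustrative; actual values come from LinearMappingConfig):
    #   pixel_values:   (batch, 3, image_resize, image_resize)
    #   input_ids:      (batch, 77)
    #   attention_mask: (batch, config.prefix_length + 77)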
    def tokenize_fn(examples):
        texts = list(examples[CAPTION_COLUMN])
        if config.add_image_token:
            # prepend the tokenizer's CLS token as the extra image token
            texts = [processor.tokenizer.cls_token + text for text in texts]
        inputs = processor.tokenizer(
            texts, padding="max_length", max_length=77,
            return_tensors="pt", truncation=True
        )
        examples["input_ids"] = inputs.input_ids
        examples["attention_mask"] = inputs.attention_mask
        return examples
    # CLIP's image normalization statistics; Transform is assumed to resize,
    # convert to float, and normalize the decoded uint8 CHW tensors
    image_transformations = Transform(
        config.image_resize,
        [0.48145466, 0.4578275, 0.40821073],   # CLIP image mean
        [0.26862954, 0.26130258, 0.27577711],  # CLIP image std
    )
    # script the transform so per-image preprocessing avoids Python overhead
    image_transformations = torch.jit.script(image_transformations)
    def transform_images(examples):
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        examples["pixel_values"] = [image_transformations(image) for image in images]
        # extend the text attention mask with ones over the image-prefix positions
        examples["attention_mask"] = torch.cat([
            torch.ones(len(images), config.prefix_length),
            torch.tensor(examples["attention_mask"])
        ], dim=1).to(dtype=torch.long)
        return examples
    def preprocess_fn(examples):
        """One-shot alternative to tokenize_fn + transform_images (unused below)."""
        texts = list(examples[CAPTION_COLUMN])
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        inputs = processor(
            texts=texts, images=images, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        )
        return inputs
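    # A sketch of how the one-shot path could replace the lazy pipeline below,
    # if caching decoded pixel values to disk is acceptable (hypothetical usage):
    #
    #     train_dataset = train_dataset.map(
    #         function=preprocess_fn,
    #         batched=True,
    #         remove_columns=train_dataset.column_names,
    #     )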
    def filter_corrupt_images(examples):
        """Return a keep-mask that drops images PIL cannot open."""
        valid_images = []
        for image_file in examples[IMAGE_COLUMN]:
            try:
                Image.open(image_file)
                valid_images.append(True)
            except Exception:
                valid_images.append(False)
        return valid_images
| train_dataset = ds["train"] | |
| train_dataset = train_dataset.filter( | |
| function=filter_corrupt_images, | |
| batched=True | |
| ) | |
| train_dataset = train_dataset.map( | |
| function=tokenize_fn, | |
| batched=True, | |
| remove_columns=[col for col in train_dataset.column_names if col != IMAGE_COLUMN and col != CAPTION_COLUMN], | |
| load_from_cache_file=True | |
| ) | |
| train_dataset.set_transform(transform_images) | |
    training_args = TrainingArguments(
        learning_rate=5e-4,
        lr_scheduler_type="cosine",
        output_dir="clip-gpt2-image-captioner",
        do_train=True,
        logging_steps=50,
        num_train_epochs=5,
        logging_dir="runs",
        remove_unused_columns=False,  # keep image_path/caption so set_transform still sees them
        max_grad_norm=1.0,
        per_device_train_batch_size=16,
        save_total_limit=3,
        warmup_steps=500
    )
    model = LinearMapping(config)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collate_fn
    )
    trainer.train()


if __name__ == "__main__":
    main()
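# To train on multiple GPUs, the script can be launched with torchrun, e.g.
# `torchrun --nproc_per_node=4 <this_script>.py`; Trainer picks up the
# distributed environment automatically.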