import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
import random
from dataset.concat_dataset import ConCatDataset #, collate_fn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import os
from tqdm import tqdm
from distributed import get_rank, synchronize, get_world_size
from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap_loader #, get_padded_boxes
from PIL import Image
import math
import json


def draw_masks_from_boxes(boxes, size):
    """For each image, rasterize a mask that is 1 everywhere except inside the
    given boxes (normalized x0, y0, x1, y1), which are zeroed out. Assumes a
    square spatial size. Returns a (B, 1, H, W) tensor."""
    if isinstance(size, int):  # the caller passes a scalar image_size; treat it as square
        size = (size, size)
    image_masks = []
    for box in boxes:
        image_mask = torch.ones(size[0], size[1])
        for bx in box:
            x0, x1 = bx[0] * size[0], bx[2] * size[0]
            y0, y1 = bx[1] * size[1], bx[3] * size[1]
            image_mask[int(y0):int(y1), int(x0):int(x1)] = 0
        image_masks.append(image_mask)
    return torch.stack(image_masks).unsqueeze(1)
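
# Example (hypothetical values): one image with a single box covering the left
# half of the frame, boxes = [torch.tensor([[0.0, 0.0, 0.5, 1.0]])] and size 64,
# yields a (1, 1, 64, 64) mask that is 0 over columns 0-31 and 1 elsewhere; the
# zeroed region is presumably what the sampler regenerates during inpainting.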


def set_alpha_scale(model, alpha_scale):
    """Set the gating scale on every gated cross/self-attention layer in the model."""
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense
    for module in model.modules():
        if type(module) in (GatedCrossAttentionDense, GatedSelfAttentionDense):
            module.scale = alpha_scale
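
# The gated attention layers scale their output by `scale` before adding it back
# to the visual feature stream, so alpha_scale = 1.0 applies full grounding and
# alpha_scale = 0.0 presumably falls back to the vanilla (ungrounded) backbone.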


def save_images(samples, image_ids, folder, to256):
    """Save a batch of samples in [-1, 1] as <image_id>.png, optionally resized to 256x256."""
    for sample, image_id in zip(samples, image_ids):
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5  # map [-1, 1] to [0, 1]
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255   # CHW float -> HWC in [0, 255]
        img_name = str(int(image_id)) + '.png'
        img = Image.fromarray(sample.astype(np.uint8))
        if to256:
            img = img.resize((256, 256), Image.BICUBIC)
        img.save(os.path.join(folder, img_name))


def ckpt_to_folder_name(basename):
    """Turn the digits in a checkpoint filename (the iteration count) into a
    folder name in thousands of iterations, e.g. 'Xk'."""
    name = ""
    for s in basename:
        if s.isdigit():
            name += s
    seen = round(int(name) / 1000, 1)
    return str(seen).ljust(4, '0') + 'k'
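
# Worked example: ckpt_to_folder_name("checkpoint_00030000.pth") collects the
# digits "00030000", computes round(30000 / 1000, 1) = 30.0, and returns "30.0k".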


class Evaluator:
    def __init__(self, config):

        self.config = config
        self.device = torch.device("cuda")

        # = = = = = create model and diffusion = = = = = #
        if self.config.ckpt != "real":
            self.model = instantiate_from_config(config.model).to(self.device)
            self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
            self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
            self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

            # no need to load the official ckpt into self.model here, since it will be overwritten by our ckpt below
            state_dict = read_official_ckpt(os.path.join(config.DATA_ROOT, config.official_ckpt_name))
            self.autoencoder.load_state_dict(state_dict["autoencoder"])
            self.text_encoder.load_state_dict(state_dict["text_encoder"])
            self.diffusion.load_state_dict(state_dict["diffusion"])

        # = = = = = load from our ckpt = = = = = #
        if self.config.ckpt == "real":
            print("Saving all real images...")
            self.just_save_real = True
        else:
            checkpoint = torch.load(self.config.ckpt, map_location="cpu")
            which_state = 'ema' if 'ema' in checkpoint else "model"
            which_state = which_state if config.which_state is None else config.which_state
            self.model.load_state_dict(checkpoint[which_state])
            print("ckpt is loaded")
            self.just_save_real = False
            set_alpha_scale(self.model, self.config.alpha_scale)
            self.autoencoder.eval()
            self.model.eval()
            self.text_encoder.eval()

        # = = = = = create data = = = = = #
        self.dataset_eval = ConCatDataset(config.val_dataset_names, config.DATA_ROOT, config.which_embedder, train=False)
        print("total eval images: ", len(self.dataset_eval))
        sampler = DistributedSampler(self.dataset_eval, shuffle=False) if config.distributed else None
        loader_eval = DataLoader(self.dataset_eval, batch_size=config.batch_size,
                                 num_workers=config.workers,
                                 pin_memory=True,
                                 sampler=sampler,
                                 drop_last=False)  # shuffle defaults to False
        self.loader_eval = loader_eval

        # = = = = = create output folder = = = = = #
        folder_name = ckpt_to_folder_name(os.path.basename(config.ckpt))
        self.outdir = os.path.join(config.OUTPUT_ROOT, folder_name)
        self.outdir_real = os.path.join(self.outdir, 'real')
        self.outdir_fake = os.path.join(self.outdir, 'fake')
        if config.to256:
            self.outdir_real256 = os.path.join(self.outdir, 'real256')
            self.outdir_fake256 = os.path.join(self.outdir, 'fake256')
        synchronize()  # if rank0 is faster, it may mkdir before the other ranks call os.listdir()
        if get_rank() == 0:
            os.makedirs(self.outdir, exist_ok=True)
            os.makedirs(self.outdir_real, exist_ok=True)
            os.makedirs(self.outdir_fake, exist_ok=True)
            if config.to256:
                os.makedirs(self.outdir_real256, exist_ok=True)
                os.makedirs(self.outdir_fake256, exist_ok=True)
            print(self.outdir)  # double check

        self.evaluation_finished = False
        if os.path.exists(os.path.join(self.outdir, 'score.txt')):
            self.evaluation_finished = True

    def already_saved_this_batch(self, batch):
        """Return True only if every id in the batch already has both a real and a fake PNG on disk."""
        existing_real_files = os.listdir(self.outdir_real)
        existing_fake_files = os.listdir(self.outdir_fake)
        status = []
        for image_id in batch["id"]:
            img_name = str(int(image_id)) + '.png'
            status.append(img_name in existing_real_files)
            status.append(img_name in existing_fake_files)
        return all(status)
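
    # Note: start_evaluating() currently bypasses this check ("if True:"); restoring
    # the commented-out call would let an interrupted run skip already-finished batches.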

    def start_evaluating(self):

        iterator = tqdm(self.loader_eval, desc='Evaluating progress')
        for batch in iterator:

            # if not self.already_saved_this_batch(batch):
            if True:

                batch_to_device(batch, self.device)
                batch_size = batch["image"].shape[0]
                samples_real = batch["image"]

                if self.just_save_real:
                    samples_fake = None
                else:
                    uc = self.text_encoder.encode(batch_size * [""])  # unconditional (empty-caption) embedding for guidance
                    context = self.text_encoder.encode(batch["caption"])

                    image_mask = x0 = None
                    if self.config.inpaint:
                        image_mask = draw_masks_from_boxes(batch['boxes'], self.model.image_size).cuda()
                        x0 = self.autoencoder.encode(batch["image"])

                    shape = (batch_size, self.model.in_channels, self.model.image_size, self.model.image_size)

                    if self.config.no_plms:
                        sampler = DDIMSampler(self.diffusion, self.model)
                        steps = 250
                    else:
                        sampler = PLMSSampler(self.diffusion, self.model)
                        steps = 50

                    input = dict(x=None, timesteps=None, context=context, boxes=batch['boxes'], masks=batch['masks'], positive_embeddings=batch["positive_embeddings"])
                    samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=self.config.guidance_scale, mask=image_mask, x0=x0)
                    samples_fake = self.autoencoder.decode(samples_fake)

                save_images(samples_real, batch['id'], self.outdir_real, to256=False)
                if self.config.to256:
                    save_images(samples_real, batch['id'], self.outdir_real256, to256=True)

                if samples_fake is not None:
                    save_images(samples_fake, batch['id'], self.outdir_fake, to256=False)
                    if self.config.to256:
                        save_images(samples_fake, batch['id'], self.outdir_fake256, to256=True)

    def fire_fid(self):
        paths = [self.outdir_real, self.outdir_fake]
        if self.config.to256:
            paths = [self.outdir_real256, self.outdir_fake256]
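        # `paths` now holds the (real, fake) folders to compare. A minimal way to
        # score them, assuming the third-party pytorch-fid package is installed
        # (this repo's own FID code may differ):
        #   from pytorch_fid.fid_score import calculate_fid_given_paths
        #   fid = calculate_fid_given_paths(paths, batch_size=50, device=self.device, dims=2048)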