import argparse
from PIL import Image, ImageDraw
from evaluator import Evaluator
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import os
from transformers import CLIPProcessor, CLIPModel
from copy import deepcopy
import torch
from ldm.util import instantiate_from_config
from trainer import read_official_ckpt, batch_to_device
from evaluator import set_alpha_scale, save_images, draw_masks_from_boxes
import numpy as np
import clip
from functools import partial
import torchvision.transforms.functional as F
import random

device = "cuda"

def alpha_generator(length, type=[1, 0, 0]):
    """
    length is the total number of timesteps needed for sampling.
    type should be a list of three values that sum to 1, giving the fraction of
    steps spent in each stage:
        alpha=1 stage,
        linear decay stage,
        alpha=0 stage.
    For example, if length=100 and type=[0.8, 0.1, 0.1], then alpha is 1 for the
    first 80 steps, decays linearly to 0 over the next 10 steps, and is 0 for the
    last 10 steps.
    """
    assert len(type) == 3
    assert type[0] + type[1] + type[2] == 1

    stage0_length = int(type[0] * length)
    stage1_length = int(type[1] * length)
    stage2_length = length - stage0_length - stage1_length

    if stage1_length != 0:
        decay_alphas = np.arange(start=0, stop=1, step=1 / stage1_length)[::-1]
        decay_alphas = list(decay_alphas)
    else:
        decay_alphas = []

    alphas = [1] * stage0_length + decay_alphas + [0] * stage2_length
    assert len(alphas) == length

    return alphas
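
# Illustrative usage sketch (not part of the original script): with length=10 and
# type=[0.5, 0.3, 0.2], alpha stays at 1 for the first 5 steps, decays linearly
# over the next 3 steps, and is 0 for the last 2 steps.
def _example_alpha_schedule():
    schedule = alpha_generator(10, type=[0.5, 0.3, 0.2])
    assert len(schedule) == 10
    assert schedule[0] == 1 and schedule[-1] == 0
    return schedule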

def draw_box(img, locations):
    colors = ["red", "green", "blue", "olive", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    WW, HH = img.size
    for bid, box in enumerate(locations):
        draw.rectangle([box[0] * WW, box[1] * HH, box[2] * WW, box[3] * HH],
                       outline=colors[bid % len(colors)], width=5)
    return img
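
# Illustrative usage sketch (not part of the original script): draw two boxes given in
# normalized [x0, y0, x1, y1] coordinates on a blank 512x512 canvas.
def _example_draw_box():
    canvas = Image.new("RGB", (512, 512), "white")
    return draw_box(canvas, [[0.1, 0.1, 0.4, 0.5], [0.5, 0.2, 0.9, 0.8]])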

def load_common_ckpt(config, common_ckpt):
    autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
    text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
    diffusion = instantiate_from_config(config.diffusion).to(device)

    autoencoder.load_state_dict(common_ckpt["autoencoder"])
    text_encoder.load_state_dict(common_ckpt["text_encoder"])
    diffusion.load_state_dict(common_ckpt["diffusion"])

    return [autoencoder, text_encoder, diffusion]

def load_ckpt(config, state_dict, common_instances):
    model = instantiate_from_config(config.model).to(device)
    model.load_state_dict(state_dict['model'])
    set_alpha_scale(model, config.alpha_scale)
    print("ckpt is loaded")

    return [model] + common_instances

def project(x, projection_matrix):
    """
    x (Batch*768) should be the penultimate feature of CLIP (before projection).
    projection_matrix (768*768) is the CLIP projection matrix, which should be the
    weight.data of the Linear layer defined in CLIP (out_dim, in_dim); hence the
    transpose below.
    This function returns the CLIP feature (without normalization).
    """
    return x @ torch.transpose(projection_matrix, 0, 1)
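
# Illustrative shape check (not part of the original script): the projection matrix here
# is a random stand-in for the CLIP projection weight that is loaded elsewhere in this file.
def _example_project():
    x = torch.randn(2, 768)                    # penultimate CLIP features
    projection_matrix = torch.randn(768, 768)  # stand-in for weight.data of the projection layer
    return project(x, projection_matrix)       # -> shape (2, 768)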

def get_clip_feature(model, processor, input, is_image=False):
    # With the fixed feature_type below, text inputs use the penultimate ('before'
    # projection) feature and image inputs use the re-projected image embedding.
    feature_type = ['before', 'after_reproject']  # text feature, image feature

    if is_image:
        image = input  # Image.open(input).convert("RGB")
        inputs = processor(images=[image], return_tensors="pt", padding=True)
        inputs['pixel_values'] = inputs['pixel_values'].cuda()  # we use our own preprocessing without center_crop
        inputs['input_ids'] = torch.tensor([[0, 1, 2, 3]]).cuda()  # placeholder
        outputs = model(**inputs)
        feature = outputs.image_embeds
        if feature_type[1] == 'after_renorm':
            feature = feature * 28.7
        if feature_type[1] == 'after_reproject':
            feature = project(feature, torch.load('gligen/projection_matrix.pth').cuda().T).squeeze(0)
            feature = (feature / feature.norm()) * 28.7
            feature = feature.unsqueeze(0)
    else:
        inputs = processor(text=input, return_tensors="pt", padding=True)
        inputs['input_ids'] = inputs['input_ids'].cuda()
        inputs['pixel_values'] = torch.ones(1, 3, 224, 224).cuda()  # placeholder
        inputs['attention_mask'] = inputs['attention_mask'].cuda()
        outputs = model(**inputs)
        feature = outputs.text_embeds if feature_type[0] == 'after' else outputs.text_model_output.pooler_output

    return feature

def complete_mask(has_mask, max_objs):
    mask = torch.ones(1, max_objs)
    if type(has_mask) == int or type(has_mask) == float:
        return mask * has_mask
    else:
        for idx, value in enumerate(has_mask):
            mask[0, idx] = value
        return mask
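
# Illustrative usage sketch (not part of the original script): a scalar is broadcast to
# every slot, while a per-object list overwrites the leading slots and leaves the rest at 1.
def _example_complete_mask():
    all_on = complete_mask(1, max_objs=30)           # (1, 30) tensor of ones
    per_obj = complete_mask([1, 0, 1], max_objs=30)  # slots 0-2 set to 1, 0, 1; slots 3-29 stay 1
    return all_on, per_obj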

def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None):
    # import pdb; pdb.set_trace()
    phrases = meta["phrases"]
    images = meta["images"]

    if clip_model is None:
        version = "openai/clip-vit-large-patch14"
        model = CLIPModel.from_pretrained(version).cuda()
        processor = CLIPProcessor.from_pretrained(version)
    else:
        version = "openai/clip-vit-large-patch14"
        assert clip_model['version'] == version
        model = clip_model['model']
        processor = clip_model['processor']

    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, 768)
    image_embeddings = torch.zeros(max_objs, 768)

    text_features = []
    image_features = []
    for phrase, image in zip(phrases, images):
        text_features.append(get_clip_feature(model, processor, phrase, is_image=False))
        image_features.append(get_clip_feature(model, processor, image, is_image=True))

    if len(text_features) > 0:
        text_features = torch.cat(text_features, dim=0)
        image_features = torch.cat(image_features, dim=0)

        for idx, (box, text_feature, image_feature) in enumerate(zip(meta['locations'], text_features, image_features)):
            boxes[idx] = torch.tensor(box)
            masks[idx] = 1
            text_embeddings[idx] = text_feature
            image_embeddings[idx] = image_feature

    out = {
        "boxes": boxes.unsqueeze(0).repeat(batch, 1, 1),
        "masks": masks.unsqueeze(0).repeat(batch, 1),
        "text_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_text_mask"], max_objs),
        "image_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_image_mask"], max_objs),
        "text_embeddings": text_embeddings.unsqueeze(0).repeat(batch, 1, 1),
        "image_embeddings": image_embeddings.unsqueeze(0).repeat(batch, 1, 1),
    }

    return batch_to_device(out, device)

def remove_numbers(text):
    result = ''.join([char for char in text if not char.isdigit()])
    return result

def process_box_phrase(names, bboxes):
    d = {}
    for i, phrase in enumerate(names):
        phrase = phrase.replace('_', ' ')
        list_noun = phrase.split(' ')
        for n in list_noun:
            n = remove_numbers(n)
            if n not in d:
                d.update({n: [np.array(bboxes[i])]})
            else:
                d[n].append(np.array(bboxes[i]))
    return d
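
# Illustrative usage sketch (not part of the original script): underscores are replaced by
# spaces, digits are stripped, and each remaining word maps to the boxes of every phrase
# that contains it ("apple" below collects both boxes).
def _example_process_box_phrase():
    names = ["red_apple", "green apple"]
    bboxes = [[0.1, 0.1, 0.3, 0.3], [0.6, 0.6, 0.9, 0.9]]
    return process_box_phrase(names, bboxes)  # {'red': [box0], 'apple': [box0, box1], 'green': [box1]}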

def Pharse2idx_2(prompt, name_box):
    prompt = prompt.replace('.', '')
    prompt = prompt.replace(',', '')
    prompt_list = prompt.strip('.').split(' ')

    object_positions = []
    bbox_to_self_att = []
    for obj in name_box.keys():
        obj_position = []
        in_prompt = False
        for word in obj.split(' '):
            if word in prompt_list:
                obj_first_index = prompt_list.index(word) + 1
                obj_position.append(obj_first_index)
                in_prompt = True
            elif word + 's' in prompt_list:
                obj_first_index = prompt_list.index(word + 's') + 1
                obj_position.append(obj_first_index)
                in_prompt = True
            elif word + 'es' in prompt_list:
                obj_first_index = prompt_list.index(word + 'es') + 1
                obj_position.append(obj_first_index)
                in_prompt = True
        if in_prompt:
            bbox_to_self_att.append(np.array(name_box[obj]))
            object_positions.append(obj_position)

    return object_positions, bbox_to_self_att
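
# Illustrative usage sketch (not part of the original script): positions are 1-based word
# indices into the prompt, and only words actually found in the prompt (optionally with an
# 's'/'es' plural suffix) contribute their boxes.
def _example_phrase_to_idx():
    prompt = "a red apple next to a green apple"
    name_box = process_box_phrase(["red_apple", "green apple"],
                                  [[0.1, 0.1, 0.3, 0.3], [0.6, 0.6, 0.9, 0.9]])
    positions, boxes_att = Pharse2idx_2(prompt, name_box)
    return positions, boxes_att  # positions == [[2], [3], [7]]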

# @torch.no_grad()
def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):
    # -------------- prepare model and misc --------------- #
    model, autoencoder, text_encoder, diffusion = loaded_model_list

    batch_size = instruction["batch_size"]
    is_inpaint = True if "input_image" in instruction else False
    save_folder = os.path.join("create_samples", instruction["save_folder_name"])

    # -------------- set seed if required --------------- #
    if instruction.get('fix_seed', False):
        random_seed = instruction['rand_seed']
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # ------------- prepare input for the model ------------- #
    with torch.no_grad():
        batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None))
        context = text_encoder.encode([instruction["prompt"]] * batch_size)
        uc = text_encoder.encode(batch_size * [""])

    name_box = process_box_phrase(instruction['phrases'], instruction['locations'])
    position, box_att = Pharse2idx_2(instruction['prompt'], name_box)

    input = dict(x=None,
                 timesteps=None,
                 context=context,
                 boxes=batch['boxes'],
                 masks=batch['masks'],
                 text_masks=batch['text_masks'],
                 image_masks=batch['image_masks'],
                 text_embeddings=batch["text_embeddings"],
                 image_embeddings=batch["image_embeddings"],
                 boxes_att=box_att,
                 object_position=position)

    inpainting_mask = x0 = None  # used for inpainting
    if is_inpaint:
        input_image = F.pil_to_tensor(instruction["input_image"])
        input_image = (input_image.float().unsqueeze(0).cuda() / 255 - 0.5) / 0.5
        x0 = autoencoder.encode(input_image)
        if instruction["actual_mask"] is not None:
            inpainting_mask = instruction["actual_mask"][None, None].expand(batch['boxes'].shape[0], -1, -1, -1).cuda()
        else:
            actual_boxes = [instruction['inpainting_boxes_nodrop'] for _ in range(batch['boxes'].shape[0])]
            inpainting_mask = draw_masks_from_boxes(actual_boxes, (x0.shape[-2], x0.shape[-1])).cuda()
        masked_x0 = x0 * inpainting_mask
        inpainting_extra_input = torch.cat([masked_x0, inpainting_mask], dim=1)
        input["inpainting_extra_input"] = inpainting_extra_input

    # ------------- prepare sampler ------------- #
    alpha_generator_func = partial(alpha_generator, type=instruction["alpha_type"])
    if False:  # DDIM path is disabled; PLMS with 50 steps is always used
        sampler = DDIMSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 250
    else:
        sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 50

    # ------------- run sampler ... ------------- #
    shape = (batch_size, model.in_channels, model.image_size, model.image_size)
    samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=instruction['guidance_scale'], mask=inpainting_mask, x0=x0)
    with torch.no_grad():
        samples_fake = autoencoder.decode(samples_fake)

    # ------------- other logistics ------------- #
    sample_list = []
    for sample in samples_fake:
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
        sample = Image.fromarray(sample.astype(np.uint8))
        sample_list.append(sample)

    return sample_list, None
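
# Illustrative instruction sketch (not part of the original script): the keys below are the
# ones read by fire_clip and grounded_generation_box in this file; loading the model list
# and invoking grounded_generation_box are omitted since they require the GLIGEN checkpoints.
def _example_instruction():
    placeholder = Image.new("RGB", (224, 224), "white")  # stand-in reference image per phrase
    return dict(
        prompt="a red apple next to a green apple",
        phrases=["red apple", "green apple"],
        images=[placeholder, placeholder],                 # one entry per phrase
        locations=[[0.1, 0.1, 0.4, 0.5], [0.6, 0.2, 0.9, 0.7]],
        has_text_mask=1,
        has_image_mask=0,                                  # ignore the placeholder image features
        alpha_type=[0.3, 0.0, 0.7],
        guidance_scale=7.5,
        batch_size=1,
        save_folder_name="example",
        fix_seed=True,
        rand_seed=42,
    )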

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
#     parser.add_argument("--official_ckpt", type=str, default='../../../data/sd-v1-4.ckpt', help="")
#     parser.add_argument("--batch_size", type=int, default=10, help="This will overwrite the one in yaml.")
#     parser.add_argument("--no_plms", action='store_true')
#     parser.add_argument("--guidance_scale", type=float, default=5, help="")
#     parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
#     args = parser.parse_args()
#
#     assert "sd-v1-4.ckpt" in args.official_ckpt, "only support for stable-diffusion model"
#     grounded_generation(args)