import gc
import os
import io
import math
import sys
import tempfile
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
import numpy as np
from math import log2, sqrt
import argparse
import pickle

################################### mask_fusion ######################################
from util.metrics_accumulator import MetricsAccumulator

metrics_accumulator = MetricsAccumulator()

from pathlib import Path
################################### mask_fusion ######################################

import clip
import lpips
from torch.nn.functional import mse_loss

################################### CLIPseg ######################################
from torchvision import utils as vutils
import cv2
################################### CLIPseg ######################################


def str2bool(x):
    # Only the string "true" (case-insensitive) maps to True. The trailing comma matters:
    # `in ('true')` would be a substring test against 'true', not a tuple membership test.
    return x.lower() in ('true',)


USE_CPU = False
device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')


def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')
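
# Usage note (illustrative only; the URL and path below are placeholders):
#   img = Image.open(fetch('https://example.com/dog.png'))   # remote file, downloaded via requests
#   img = Image.open(fetch('local/dog.png'))                 # local path, opened directly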


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)
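
# A minimal usage sketch of MakeCutouts (illustrative only, not executed at import time;
# shapes assume a 256x256 input and CLIP's 224-pixel visual resolution):
#   cutter = MakeCutouts(cut_size=224, cutn=16)
#   crops = cutter(torch.rand(1, 3, 256, 256))   # -> tensor of shape (16, 3, 224, 224)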


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
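
# What spherical_dist_loss computes, spelled out: for unit vectors separated by angle theta,
# ||x - y|| = 2 * sin(theta / 2), so the expression above is 2 * (theta / 2) ** 2, i.e. half the
# squared angular distance between the normalized embeddings. Illustrative example:
#   spherical_dist_loss(torch.tensor([[1., 0.]]), torch.tensor([[0., 1.]]))
#   # -> ~1.2337, i.e. (pi / 2) ** 2 / 2 for orthogonal embeddings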


def do_run(
    arg_seed, arg_text, arg_batch_size, arg_num_batches, arg_negative, arg_cutn, arg_edit, arg_height, arg_width,
    arg_edit_y, arg_edit_x, arg_edit_width, arg_edit_height, mask, arg_guidance_scale, arg_background_preservation_loss,
    arg_lpips_sim_lambda, arg_l2_sim_lambda, arg_ddpm, arg_ddim, arg_enforce_background, arg_clip_guidance_scale,
    arg_clip_guidance, model_params, model, diffusion, ldm, bert, clip_model
):
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    if arg_seed >= 0:
        torch.manual_seed(arg_seed)

    # Text conditioning: BERT embeddings for the diffusion model, CLIP embeddings for guidance.
    text_emb = bert.encode([arg_text] * arg_batch_size).to(device).float()
    text_blank = bert.encode([arg_negative] * arg_batch_size).to(device).float()
    text = clip.tokenize([arg_text] * arg_batch_size, truncate=True).to(device)
    text_clip_blank = clip.tokenize([arg_negative] * arg_batch_size, truncate=True).to(device)
    text_emb_clip = clip_model.encode_text(text)
    text_emb_clip_blank = clip_model.encode_text(text_clip_blank)
    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, arg_cutn)
    text_emb_norm = text_emb_clip[0] / text_emb_clip[0].norm(dim=-1, keepdim=True)

    image_embed = None
    if arg_edit:
        w = arg_edit_width if arg_edit_width else arg_width
        h = arg_edit_height if arg_edit_height else arg_height
        arg_edit = arg_edit.convert('RGB')
        input_image_pil = arg_edit
        init_image_pil = input_image_pil.resize((arg_height, arg_width), Image.Resampling.LANCZOS)
        input_image_pil = ImageOps.fit(input_image_pil, (w, h))
        im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
        init_image = TF.to_tensor(init_image_pil).to(device).unsqueeze(0).mul(2).sub(1)
        im = 2 * im - 1
        im = ldm.encode(im).sample()
        # Paste the encoded latent into a blank latent canvas at the requested edit offset,
        # cropping whatever falls outside the canvas.
        y = arg_edit_y // 8
        x = arg_edit_x // 8
        input_image = torch.zeros(1, 4, arg_height // 8, arg_width // 8, device=device)
        ycrop = y + im.shape[2] - input_image.shape[2]
        xcrop = x + im.shape[3] - input_image.shape[3]
        ycrop = ycrop if ycrop > 0 else 0
        xcrop = xcrop if xcrop > 0 else 0
        input_image[0, :, y if y >= 0 else 0:y + im.shape[2], x if x >= 0 else 0:x + im.shape[3]] = im[:, :, 0 if y > 0 else -y:im.shape[2] - ycrop, 0 if x > 0 else -x:im.shape[3] - xcrop]
        input_image_pil = ldm.decode(input_image)
        input_image_pil = TF.to_pil_image(input_image_pil.squeeze(0).add(1).div(2).clamp(0, 1))
        input_image *= 0.18215  # LDM latent scaling factor
        # The mask is 1 on the background and 0 on the region to edit; zeroing the latent
        # inside the edit region tells the inpainting model to regenerate only that area.
        new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width // 8, arg_height // 8))
        mask1 = (new_mask > 0.5).float()
        input_image *= mask1
        image_embed = torch.cat(arg_batch_size * 2 * [input_image], dim=0).float()
    elif model_params['image_condition']:
        # Using the inpainting model, but no image was provided.
        image_embed = torch.zeros(arg_batch_size * 2, 4, arg_height // 8, arg_width // 8, device=device)

    kwargs = {
        "context": torch.cat([text_emb, text_blank], dim=0).float(),
        "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float() if model_params['clip_embed_dim'] else None,
        "image_embed": image_embed
    }

    # Create a classifier-free guidance sampling function.
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + arg_guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
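
    # What model_fn does, spelled out (a sketch of classifier-free guidance; the batch is laid
    # out as [conditional half, unconditional half], matching the `context`/`clip_embed` above):
    #   eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    # A guidance scale of 1 reduces to the conditional prediction; larger values push the sample
    # further toward the text prompt and away from the negative prompt.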

    cur_t = None

    # Defined for per-step background blending, but not passed to the sampler below.
    def postprocess_fn(out, t):
        if mask is not None:
            background_stage_t = diffusion.q_sample(init_image, t[0])
            background_stage_t = torch.tile(
                background_stage_t, dims=(arg_batch_size, 1, 1, 1)
            )
            out["sample"] = out["sample"] * mask + background_stage_t * (1 - mask)
        return out

    # if arg_ddpm:
    #     sample_fn = diffusion.p_sample_loop_progressive
    # elif arg_ddim:
    #     sample_fn = diffusion.ddim_sample_loop_progressive
    # else:
    sample_fn = diffusion.plms_sample_loop_progressive

    def save_sample(i, sample):
        out_ims = []
        for k, image in enumerate(sample['pred_xstart'][:arg_batch_size]):
            image /= 0.18215  # undo the LDM latent scaling before decoding
            im = image.unsqueeze(0)
            out = ldm.decode(im)
            metrics_accumulator.print_average_metric()
            for b in range(arg_batch_size):
                pred_image = sample["pred_xstart"][b]
                if arg_enforce_background:
                    # Keep the original pixels wherever the mask is 1 (background) and the
                    # decoded sample only inside the edit region.
                    new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width, arg_height))
                    pred_image = init_image[0] * new_mask[0] + out * (1 - new_mask[0])
                pred_image_pil = TF.to_pil_image(pred_image.squeeze(0).add(1).div(2).clamp(0, 1))
                out_ims.append(pred_image_pil)
        return out_ims

    all_saved_ims = []
    for i in range(arg_num_batches):
        cur_t = diffusion.num_timesteps - 1
        samples = sample_fn(
            model_fn,
            (arg_batch_size * 2, 4, int(arg_height // 8), int(arg_width // 8)),
            clip_denoised=False,
            model_kwargs=kwargs,
            cond_fn=None,
            device=device,
            progress=True,
        )
        for j, sample in enumerate(samples):
            cur_t -= 1
            # Save an intermediate preview every 5 steps; the final sample is saved after the loop.
            if j % 5 == 0 and j != diffusion.num_timesteps - 1:
                all_saved_ims += save_sample(i, sample)
        all_saved_ims += save_sample(i, sample)
    return all_saved_ims


def run_model(
    segmodel, model, diffusion, ldm, bert, clip_model, model_params,
    from_text, instruction, negative_prompt, original_img, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
):
    input_image = original_img
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((256, 256)),
    ])
    img = transform(input_image).unsqueeze(0)

    # Predict the region described by `from_text` with the CLIPSeg segmentation model.
    with torch.no_grad():
        preds = segmodel(img.repeat(1, 1, 1, 1), from_text)[0]
    mask = torch.sigmoid(preds[0][0])

    # Binarize the soft mask: truncate values above 100 down to 100, then map exactly 100 -> 255
    # (segmented region) and everything below 100 -> 0.
    image = (mask.detach().cpu().numpy() * 255).astype(np.uint8)
    ret, thresh = cv2.threshold(image, 100, 255, cv2.THRESH_TRUNC, image)
    timg = np.array(thresh)
    x, y = timg.shape
    for row in range(x):
        for col in range(y):
            if timg[row][col] == 100:
                timg[row][col] = 255
            if timg[row][col] < 100:
                timg[row][col] = 0

    # Invert so the background is 1 and the region to edit is 0, the convention do_run expects.
    fulltensor = torch.full_like(mask, fill_value=255)
    bgtensor = fulltensor - timg
    mask = bgtensor / 255.0

    gc.collect()
    use_ddim = False
    use_ddpm = False
    all_saved_ims = do_run(
        seed, instruction, 1, 1, negative_prompt, cutn, input_image, 256, 256,
        0, 0, 0, 0, mask, guidance_scale, True,
        1000, l2_sim_lambda, use_ddpm, use_ddim, True, clip_guidance_scale, False,
        model_params, model, diffusion, ldm, bert, clip_model
    )
    return all_saved_ims[-1]
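

# A hypothetical end-to-end call, assuming the CLIPSeg model (`segmodel`), the text-conditioned
# diffusion model (`model`, `diffusion`), the latent autoencoder (`ldm`), `bert`, `clip_model`,
# and `model_params` have been loaded elsewhere in the app; the prompts and file names below
# are placeholders:
#   edited = run_model(
#       segmodel, model, diffusion, ldm, bert, clip_model, model_params,
#       from_text='the dog', instruction='a dog wearing sunglasses', negative_prompt='',
#       original_img=Image.open('input.png'), seed=42, guidance_scale=5.0,
#       clip_guidance_scale=150, cutn=16, l2_sim_lambda=10000,
#   )
#   edited.save('edited.png')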