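"""Stable Diffusion image generation.

Wraps the CLIP text encoder, VAE, and U-Net from CompVis/stable-diffusion-v1-4,
plus an LMS discrete scheduler, in an ImageGenerator class that supports
text-to-image and image-to-image generation with optional prompt mixing and
negative prompts, and records intermediate latents for inspection.
"""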
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
from fastcore.all import concat
from huggingface_hub import notebook_login
from PIL import Image
# from IPython.display import display
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

logging.disable(logging.WARNING)


class ImageGenerator:
    def __init__(self):
        self.latent_images = []
        self.width = 512
        self.height = 512
        self.generator = torch.manual_seed(32)
        self.bs = 1  # batch size
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.float_size = torch.float16
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            self.float_size = torch.float32
        else:
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not "
                      "built with MPS enabled.")
            else:
                print("MPS not available because the current macOS version is not 12.3+ "
                      "and/or you do not have an MPS-enabled device on this machine.")
            self.device = torch.device("cpu")
            self.float_size = torch.float32
        print(f"pytorch device: {self.device}")

    def __repr__(self):
        return f"Image Generator with {self.width=} {self.height=}"

    def load_models(self):
        '''load the CLIP tokenizer/text encoder and the Stable Diffusion VAE and U-Net from the Hugging Face Hub.'''
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=self.float_size)
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=self.float_size).to(self.device)
        # vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to(self.device)
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(self.device)  # torch_dtype=torch.float16,
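
    # The default betas below are the values Stable Diffusion v1 was trained with
    # (a scaled-linear schedule from 0.00085 to 0.012 over 1000 timesteps), so the
    # LMS scheduler matches the noise schedule the U-Net saw during training.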
    def load_scheduler(self,
                       beta_start: float = 0.00085,
                       beta_end: float = 0.012,
                       num_train_timesteps: int = 1000):
        self.scheduler = LMSDiscreteScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            num_train_timesteps=num_train_timesteps)

    def load_image(self, filepath: str):
        return Image.open(filepath).resize(size=(self.width, self.height))
        # .convert("RGB")  # RGB = 3 channels, RGBA = 4 channels
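
    # The VAE maps a 512x512 RGB image to a 4x64x64 latent (8x spatial
    # downsampling). Latents are multiplied by 0.18215, the scaling factor the
    # Stable Diffusion authors used to give latents roughly unit variance;
    # decoding undoes it with 1 / 0.18215.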

    def pil_to_latent(self, image: Image.Image) -> torch.Tensor:
        with torch.no_grad():
            image = image.resize(size=(self.width, self.height))
            np_img = np.transpose(((np.array(image) / 255) - 0.5) * 2, (2, 0, 1))  # turn the PIL image into an np array with values between -1 and 1
            # print(f"{np_img.shape=}")  # (3, 512, 512) for an RGB input
            np_images = np.repeat(np_img[np.newaxis, :, :], self.bs, axis=0).astype(np.float32)  # add a batch dimension and repeat the image for each prompt; float32 required for mac
            # print(f"{np_images.shape=}")
            decoded_latent = torch.from_numpy(np_images).to(self.device).float()  # <-- stability-ai vae uses half(), compvis vae uses float?
            # print(f"{decoded_latent.shape=}")
            encoded_latent = 0.18215 * self.vae.encode(decoded_latent).latent_dist.sample()
            # print(f"{encoded_latent.shape=}")
            return encoded_latent

    def add_noise(self, latent: torch.Tensor, scheduler_steps: int = 10) -> torch.FloatTensor:
        # noise = torch.randn_like(latent)  # missing generator parameter
        noise = torch.randn(
            size=(self.bs, self.unet.config.in_channels, self.height // 8, self.width // 8),
            generator=self.generator).to(self.device)
        timesteps = torch.tensor([self.scheduler.timesteps[scheduler_steps]])
        noisy_latent = self.scheduler.add_noise(latent, noise, timesteps)
        # print(f"add_noise: {timesteps.shape=} {timesteps=} {noisy_latent.shape=}")
        return noisy_latent

    def latent_to_pil(self, latent: torch.Tensor) -> Image.Image:
        # print(f"latent_to_pil {latent.dtype=}")
        with torch.no_grad():
            decoded = self.vae.decode(1 / 0.18215 * latent).sample[0]
            # print(f"latent_to_pil {decoded.shape=}")
            image = (decoded / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
            return Image.fromarray((image * 255).round().astype("uint8"))

    def image_grid(self, imgs: list[Image.Image]) -> Image.Image:
        '''paste a list of images into a single-row grid.'''
        # print(len(imgs))
        w, h = imgs[0].size
        cols = len(imgs)
        grid = Image.new('RGB', size=(cols * w, h))
        for i, img in enumerate(imgs):
            # print(f"{img.size=}")
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    def text_enc(self, prompt: str, maxlen=None) -> torch.Tensor:
        '''tokenize and encode a prompt'''
        if maxlen is None: maxlen = self.tokenizer.model_max_length
        inp = self.tokenizer([prompt], padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
        return self.text_encoder(inp.input_ids.to(self.device))[0].float()

    def tensor_to_pil(self, t: torch.Tensor) -> Image.Image:
        '''transforms a tensor decoded by the vae to a pil image'''
        # print(f"tensor_to_pil {t.shape=} {type(t)=}")
        image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((image * 255).round().astype("uint8"))

    def latent_callback(self, latent: torch.Tensor) -> None:
        '''store decoded latents in a list so that we can inspect them later.'''
        with torch.no_grad():
            # print(f"cb {latent.shape=}")
            decoded = self.vae.decode(1 / 0.18215 * latent).sample[0]
            self.latent_images.append(self.tensor_to_pil(decoded))
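
    # generate() runs the reverse-diffusion loop. Text-to-image starts from pure
    # noise at step 0; image-to-image encodes an init image, adds noise for the
    # chosen start step, and only denoises the remaining steps. Each step applies
    # classifier-free guidance: the U-Net predicts noise for the unconditional
    # and text-conditioned embeddings, and the final prediction is
    # u + guidance * (t - u).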
    def generate(self,
                 prompt: str = "",
                 secondary_prompt: str = None,
                 prompt_mix_ratio: float = 0.5,
                 negative_prompt="",
                 seed: int = 32,
                 guidance: float = 7.5,
                 steps: int = 30,
                 start_step_ratio: float = 1/5,
                 init_image: Image.Image = None,
                 latent_callback_mod: int = 10,
                 progress_tqdm: callable = tqdm):
        self.latent_images = []
        if not negative_prompt: negative_prompt = ""
        print(f"ImageGenerator: {prompt=} {secondary_prompt=} {prompt_mix_ratio=} {negative_prompt=} {guidance=} {steps=} {init_image=}")
        with torch.no_grad():
            text = self.text_enc(prompt)
            if secondary_prompt:
                print("using secondary prompt")
                sec_prompt_text = self.text_enc(secondary_prompt)
                text = text * prompt_mix_ratio + sec_prompt_text * (1 - prompt_mix_ratio)
            uncond = self.text_enc(negative_prompt * self.bs, text.shape[1])
            emb = torch.cat([uncond, text])
        if seed: torch.manual_seed(seed)

        self.scheduler.set_timesteps(steps)
        self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)  # float32 required for mac

        if init_image is None:
            start_steps = 0
            latents = torch.randn(
                size=(self.bs, self.unet.config.in_channels, self.height // 8, self.width // 8),
                generator=self.generator)
            latents = latents * self.scheduler.init_noise_sigma
            # print(f"{latents.shape=}")
        else:
            print("using base image")
            start_steps = int(steps * start_step_ratio)  # 0%: too much noise, 100%: no noise
            # print(f"{start_steps=}")
            latents = self.pil_to_latent(init_image)
            self.latent_callback(latents)
            latents = self.add_noise(latents, start_steps).to(self.device).float()
            self.latent_callback(latents)
        latents = latents.to(self.device).float()

        for i, ts in enumerate(progress_tqdm(self.scheduler.timesteps, desc="Latent Generation")):  # leave=False does not work with gradio
            if i >= start_steps:
                inp = self.scheduler.scale_model_input(torch.cat([latents] * 2), ts)
                with torch.no_grad():
                    u, t = self.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)  # todo: grab these with callbacks
                pred = u + guidance * (t - u)
                # pred = u + self.g*(t-u)/torch.norm(t-u)*torch.norm(u)
                latents = self.scheduler.step(pred, ts, latents).prev_sample
                if latent_callback_mod and i % latent_callback_mod == 0:
                    self.latent_callback(latents)
        return self.latent_to_pil(latents), self.latent_images
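

# A minimal usage sketch, assuming the CompVis/stable-diffusion-v1-4 weights can
# be downloaded from the Hugging Face Hub; the prompt and output filenames are
# illustrative placeholders, not part of the original app.
if __name__ == "__main__":
    gen = ImageGenerator()
    gen.load_models()
    gen.load_scheduler()
    image, latent_snapshots = gen.generate(
        prompt="a watercolor painting of a lighthouse at dawn",  # placeholder prompt
        steps=30,
        guidance=7.5,
        latent_callback_mod=10)
    image.save("output.png")  # placeholder path
    if latent_snapshots:
        # side-by-side strip of the stored intermediate denoising steps
        gen.image_grid(latent_snapshots).save("latent_steps.png")  # placeholder path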