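"""Gradio demo for PIA: Your Personalized Image Animator.

Animates a single input image from a text prompt via plug-and-play modules in
text-to-image diffusion models. Model weights are downloaded from OpenXLab
unless the app is started with --local-debug.
"""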
import json
import os
import os.path as osp
import random
from argparse import ArgumentParser
from datetime import datetime

import gradio as gr
import numpy as np
import openxlab
import torch
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from omegaconf import OmegaConf
from openxlab.model import download
from PIL import Image

from animatediff.pipelines import I2VPipeline
from animatediff.utils.util import RANGE_LIST, save_videos_grid

sample_idx = 0
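# map the UI sampler names onto their diffusers scheduler classes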
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}
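# NOTE: this mapping only populates the sampler dropdown; the chosen name is
# logged per run but never applied to the pipeline (see AnimateController.animate).
# A minimal sketch of how it could be applied, given a built diffusers pipeline:
#   pipeline.scheduler = scheduler_dict[name].from_config(pipeline.scheduler.config)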
css = """
.toolbutton {
    margin-bottom: 0;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
parser = ArgumentParser()
parser.add_argument('--config', type=str, default='example/config/base.yaml')
parser.add_argument('--server-name', type=str, default='0.0.0.0')
parser.add_argument('--port', type=int, default=7860)
parser.add_argument('--share', action='store_true')
parser.add_argument('--local-debug', action='store_true')
parser.add_argument('--save-path', default='samples')
args = parser.parse_args()

LOCAL_DEBUG = args.local_debug
BASE_CONFIG = 'example/config/base.yaml'
STYLE_CONFIG_LIST = {
    'anime': './example/openxlab/2-animation.yaml',
}
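# each style maps to an OmegaConf YAML holding its dreambooth/lora/vae paths
# and default negative prompt (consumed by AnimateController below)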
# download models
PIA_PATH = './models/PIA'
VAE_PATH = './models/VAE'
DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
if not LOCAL_DEBUG:
    CACHE_PATH = '/home/xlab-app-center/.cache/model'

    PIA_PATH = osp.join(CACHE_PATH, 'PIA')
    VAE_PATH = osp.join(CACHE_PATH, 'VAE')
    DreamBooth_LoRA_PATH = osp.join(CACHE_PATH, 'DreamBooth_LoRA')
    STABLE_DIFFUSION_PATH = osp.join(CACHE_PATH, 'StableDiffusion')
    IP_ADAPTER_PATH = osp.join(CACHE_PATH, 'IP_Adapter')

    os.makedirs(PIA_PATH, exist_ok=True)
    os.makedirs(VAE_PATH, exist_ok=True)
    os.makedirs(DreamBooth_LoRA_PATH, exist_ok=True)
    os.makedirs(STABLE_DIFFUSION_PATH, exist_ok=True)

    openxlab.login(os.environ['OPENXLAB_AK'], os.environ['OPENXLAB_SK'])

    download(model_repo='zhangyiming/PIA-pruned', model_name='PIA', output=PIA_PATH)
    download(model_repo='zhangyiming/Counterfeit-V3.0',
             model_name='Counterfeit-V3.0_fp32_pruned', output=DreamBooth_LoRA_PATH)
    download(model_repo='zhangyiming/kl-f8-anime2_VAE',
             model_name='kl-f8-anime2', output=VAE_PATH)

    # ip_adapter
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='clip_encoder', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='config', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='ip_adapter_sd15', output=IP_ADAPTER_PATH)

    # unet
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
             model_name='unet', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
    # vae
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
             model_name='vae', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
    # text encoder
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
             model_name='text_encoder', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
    # tokenizer
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='merge', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='special_tokens_map', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='tokenizer_config', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='vocab', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    # scheduler config written to disk so diffusers can assemble the pipeline;
    # kept under its own name so it does not clobber the sampler mapping above
    scheduler_config = {
        "_class_name": "PNDMScheduler",
        "_diffusers_version": "0.6.0",
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "num_train_timesteps": 1000,
        "set_alpha_to_one": False,
        "skip_prk_steps": True,
        "steps_offset": 1,
        "trained_betas": None,
        "clip_sample": False
    }
    os.makedirs(osp.join(STABLE_DIFFUSION_PATH, 'scheduler'), exist_ok=True)
    with open(osp.join(STABLE_DIFFUSION_PATH, 'scheduler', 'scheduler_config.json'), 'w') as file:
        json.dump(scheduler_config, file)
    # model index
    model_index_dict = {
        "_class_name": "StableDiffusionPipeline",
        "_diffusers_version": "0.6.0",
        "feature_extractor": [
            "transformers",
            "CLIPImageProcessor"
        ],
        "safety_checker": [
            "stable_diffusion",
            "StableDiffusionSafetyChecker"
        ],
        "scheduler": [
            "diffusers",
            "PNDMScheduler"
        ],
        "text_encoder": [
            "transformers",
            "CLIPTextModel"
        ],
        "tokenizer": [
            "transformers",
            "CLIPTokenizer"
        ],
        "unet": [
            "diffusers",
            "UNet2DConditionModel"
        ],
        "vae": [
            "diffusers",
            "AutoencoderKL"
        ]
    }
    with open(osp.join(STABLE_DIFFUSION_PATH, 'model_index.json'), 'w') as file:
        json.dump(model_index_dict, file)
else:
    # for --local-debug, weights are expected to already sit under ./models
    PIA_PATH = './models/PIA'
    VAE_PATH = './models/VAE'
    DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
    STABLE_DIFFUSION_PATH = './models/StableDiffusion/sd15'
def preprocess_img(img_np, max_size: int = 512):
    """Resize an RGB image so its long edge is at most ``max_size`` and both
    edges are multiples of 8 (the VAE downsamples by a factor of 8)."""
    ori_image = Image.fromarray(img_np).convert('RGB')

    width, height = ori_image.size
    long_edge = max(width, height)
    if long_edge > max_size:
        scale_factor = max_size / long_edge
    else:
        scale_factor = 1
    width = int(width * scale_factor)
    height = int(height * scale_factor)
    ori_image = ori_image.resize((width, height))

    # floor both edges to the nearest multiple of 8 (no-op if already aligned)
    in_width = (width // 8) * 8
    in_height = (height // 8) * 8
    in_image = ori_image.resize((in_width, in_height))

    in_image_np = np.array(in_image)
    return in_image_np, in_height, in_width
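# e.g. a 1280x720 input becomes 512x288: both edges are scaled by 512/1280,
# and 512 and 288 are already multiples of 8, so the second resize is a no-op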
class AnimateController:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.savedir = os.path.join(
            self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.inference_config = OmegaConf.load(args.config)
        self.style_configs = {k: OmegaConf.load(
            v) for k, v in STYLE_CONFIG_LIST.items()}
        self.pipeline_dict = self.load_model_list()
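    # one pipeline is built per style at startup, so switching styles in the
    # UI never pays a model-loading cost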
    def load_model_list(self):
        pipeline_dict = dict()
        for style, cfg in self.style_configs.items():
            dreambooth_path = cfg.get('dreambooth', 'none')
            if dreambooth_path and dreambooth_path.upper() != 'NONE':
                dreambooth_path = osp.join(
                    DreamBooth_LoRA_PATH, dreambooth_path)
            lora_path = cfg.get('lora', None)
            if lora_path is not None:
                lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
            lora_alpha = cfg.get('lora_alpha', 0.0)
            vae_path = cfg.get('vae', None)
            if vae_path is not None:
                vae_path = osp.join(VAE_PATH, vae_path)

            pipeline_dict[style] = I2VPipeline.build_pipeline(
                self.inference_config,
                STABLE_DIFFUSION_PATH,
                unet_path=osp.join(PIA_PATH, 'pia.ckpt'),
                dreambooth_path=dreambooth_path,
                lora_path=lora_path,
                lora_alpha=lora_alpha,
                vae_path=vae_path,
                ip_adapter_path='h94/IP-Adapter',
                ip_adapter_scale=0.1)
        return pipeline_dict
    def fetch_default_n_prompt(self, style: str):
        cfg = self.style_configs[style]
        n_prompt = cfg.get('n_prompt', '')
        ip_adapter_scale = cfg.get('real_ip_adapter_scale', 0)
        gr.Info('Set default negative prompt and ip_adapter_scale.')
        print('Set default negative prompt and ip_adapter_scale.')
        return n_prompt, ip_adapter_scale
    def animate(
        self,
        init_img,
        motion_scale,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
        ip_adapter_scale,
        style,
        progress=gr.Progress(),
    ):
        # the Seed textbox yields a string, so compare against "-1", not -1
        if seed_textbox and str(seed_textbox) != "-1":
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        pipeline = self.pipeline_dict[style]
        init_img, h, w = preprocess_img(init_img)
        sample = pipeline(
            image=init_img,
            prompt=prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=w,
            height=h,
            video_length=16,
            mask_sim_template_idx=motion_scale - 1,
            ip_adapter_scale=ip_adapter_scale,
            progress_fn=progress,
        ).videos

        # bump the counter so successive generations do not overwrite each other
        global sample_idx
        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)
        sample_idx += 1
        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": w,
            "height": h,
            "seed": seed,
            "motion": motion_scale,
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return save_sample_path
controller = AnimateController()
def ui():
    with gr.Blocks(css=css) as demo:
        gr.HTML(
            "<div align='center'><font size='7'> <img src=\"file/pia.png\" style=\"height: 72px;\"/ > Your Personalized Image Animator</font></div>"
            "<div align='center'><font size='7'>via Plug-and-Play Modules in Text-to-Image Models </font></div>"
        )
        with gr.Row():
            gr.Markdown(
                "<div align='center'><font size='5'><a href='https://pi-animator.github.io/'>Project Page</a>  "  # noqa
                "<a href='https://arxiv.org/abs/2312.13964/'>Paper</a>  "
                "<a href='https://github.com/open-mmlab/PIA'>Code</a>  "  # noqa
| # "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click Here!</a> </font></div>" # noqa | |
| "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click here! </a></font></div>" # noqa | |
| ) | |
| with gr.Row(equal_height=False): | |
| with gr.Column(): | |
| with gr.Row(): | |
| init_img = gr.Image(label='Input Image') | |
| style_dropdown = gr.Dropdown(label='Style', choices=list( | |
| STYLE_CONFIG_LIST.keys()), value=list(STYLE_CONFIG_LIST.keys())[0]) | |
| with gr.Row(): | |
| prompt_textbox = gr.Textbox(label="Prompt", lines=1) | |
| gift_button = gr.Button( | |
| value='🎁', elem_classes='toolbutton' | |
| ) | |
                    def append_gift(prompt):
                        # append a random Christmas flourish to the prompt
                        rand = random.randint(0, 2)
                        if rand == 1:
                            prompt = prompt + ', wearing santa hats'
                        elif rand == 2:
                            prompt = prompt + ', lift a Christmas gift'
                        else:
                            prompt = prompt + ', in Christmas suit, lift a Christmas gift'
                        gr.Info('Merry Christmas! Add magic to your prompt!')
                        return prompt

                    gift_button.click(
                        fn=append_gift,
                        inputs=[prompt_textbox],
                        outputs=[prompt_textbox],
                    )
                motion_scale_slider = gr.Slider(
                    label='Motion Scale (Larger value means larger motion but less identity consistency)',
                    value=2, step=1, minimum=1, maximum=len(RANGE_LIST))
                ip_adapter_scale = gr.Slider(
                    label='IP-Adapter Scale',
                    value=controller.fetch_default_n_prompt(
                        list(STYLE_CONFIG_LIST.keys())[0])[1],
                    minimum=0, maximum=1)
                with gr.Accordion('Advanced Options', open=False):
                    negative_prompt_textbox = gr.Textbox(
                        value=controller.fetch_default_n_prompt(
                            list(STYLE_CONFIG_LIST.keys())[0])[0],
                        label="Negative prompt", lines=2)
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
                            scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=20, minimum=10, maximum=100, step=1)
                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale", value=7.5, minimum=0, maximum=20)
                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(
                            # no inputs are wired in, so the fn takes no args;
                            # randint needs int bounds
                            fn=lambda: random.randint(1, int(1e8)),
                            outputs=[seed_textbox],
                            queue=False
                        )
                generate_button = gr.Button(
                    value="Generate", variant='primary')

            result_video = gr.Video(
                label="Generated Animation", interactive=False)

        style_dropdown.change(fn=controller.fetch_default_n_prompt,
                              inputs=[style_dropdown],
                              outputs=[negative_prompt_textbox, ip_adapter_scale], queue=False)
        generate_button.click(
            fn=controller.animate,
            inputs=[
                init_img,
                motion_scale_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                cfg_scale_slider,
                seed_textbox,
                ip_adapter_scale,
                style_dropdown,
            ],
            outputs=[result_video]
        )
    return demo
if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=10)
    demo.launch(server_name=args.server_name,
                server_port=args.port, share=args.share,
                max_threads=10,
                allowed_paths=['pia.png'])