Spaces:
Runtime error
Runtime error
| # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py | |
| import argparse | |
| import os | |
| import imageio | |
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from animatediff.pipelines import I2VPipeline | |
# Fallback generation settings: the script uses these when the optional
# --config file does not define 'n_prompt' / 'guidance_scale'.
N_PROMPT = 'worst quality,low quality'
GUIDANCE_SCALE = 7

# Base pipeline configuration, loaded with OmegaConf when the script runs.
BASE_CFG = './example/config/base.yaml'

# Checkpoint locations handed to I2VPipeline.build_pipeline: the base model,
# the I2V weights (passed as the UNet path) and the default DreamBooth weights
# (overridable through the 'dreambooth' key of the --config file).
BASE_MODEL = './models/StableDiffusion/sd15'
I2V_MODEL = './models/PIA/pia.ckpt'
DREAMBOOTH_PATH = './models/DreamBooth_LoRA/Counterfeit-V3.0_fp32.safetensors'
def post_process(videos: torch.Tensor):
    """Turn the first video of a batch into an array of uint8 frames.

    Only ``videos[0]`` is used. Its axes are reordered with the einops
    pattern ``"c t h w -> t h w c"``, the values are multiplied by 255,
    clipped to ``[0, 255]`` and cast to ``numpy.uint8`` (the cast truncates
    the fractional part).

    Args:
        videos: Batched video tensor; ``videos[0]`` must be 4-D so that it
            matches the ``c t h w`` pattern.
            NOTE(review): the x255 scaling suggests values in [0, 1] --
            confirm against the pipeline output.

    Returns:
        numpy.ndarray: ``uint8`` array laid out as ``t h w c``.
    """
    first_clip = videos[0]
    frames = rearrange(first_clip, "c t h w -> t h w c")
    scaled = frames * 255
    clipped = scaled.clip(0, 255)
    # Move to host memory before handing the data to NumPy.
    return clipped.cpu().numpy().astype(np.uint8)
def seed_everything(seed):
    """Seed the PyTorch, NumPy and Python ``random`` generators.

    Args:
        seed: Integer seed applied to every generator. NumPy receives
            ``seed % 2**32``; the other generators receive ``seed`` as is.

    Returns:
        None. The generators are seeded as a side effect.
    """
    import random

    import numpy as np

    # PyTorch first: the default generator, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Fold the seed into 32 bits before handing it to NumPy.
    numpy_seed = seed % (2**32)
    np.random.seed(numpy_seed)

    random.seed(seed)
def preprocess_img(img_path, max_size=512):
    """Load an image and resize it for the image-to-video pipeline.

    The image is converted to RGB, scaled down (aspect ratio preserved) so
    that its longer edge does not exceed ``max_size``, and then resized so
    that both edges are multiples of 8. Images whose longer edge is already
    within ``max_size`` are not enlarged by the first step.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        max_size: Maximum length of the longer edge in pixels. Defaults to
            512, the value that used to be hard-coded.

    Returns:
        tuple: ``(image_array, height, width)`` where ``image_array`` is the
        resized RGB image as a NumPy array and ``height`` / ``width`` are
        its final edge lengths (each a multiple of 8, at least 8).
    """
    # convert() returns an independent, fully loaded image, so the source
    # file can be closed as soon as the conversion is done.
    with Image.open(img_path) as source:
        ori_image = source.convert('RGB')

    width, height = ori_image.size
    long_edge = max(width, height)
    scale_factor = max_size / long_edge if long_edge > max_size else 1

    # With extreme aspect ratios the short edge can truncate to 0, which
    # PIL's resize rejects; keep at least one pixel.
    width = max(1, int(width * scale_factor))
    height = max(1, int(height * scale_factor))
    ori_image = ori_image.resize((width, height))

    # Round each edge down to a multiple of 8 (a no-op when it already is
    # one), but never to 0: edges shorter than 8 px are brought up to 8.
    in_width = max(8, (width // 8) * 8)
    in_height = max(8, (height // 8) * 8)
    in_image = ori_image.resize((in_width, in_height))

    in_image_np = np.array(in_image)
    return in_image_np, in_height, in_width
def main():
    """Command-line entry point: animate a single image and save the result.

    Steps:
      1. Parse the CLI arguments and, if ``--config`` is given, load it with
         OmegaConf (it may override DreamBooth / VAE / LoRA paths, the
         negative prompt, the guidance scale, a prompt suffix, ...).
      2. Build the ``I2VPipeline`` from the base config and model paths.
      3. Preprocess the input image and run the pipeline.
      4. Write the generated frames to ``--save-name`` with imageio.
    """
    parser = argparse.ArgumentParser()
    # --img and --save-name are mandatory: without them the script can only
    # fail, and for --save-name that failure used to happen after the whole
    # (expensive) generation had already run.
    parser.add_argument('--img', type=str, required=True)
    parser.add_argument("--config", type=str)
    parser.add_argument('--prompt', type=str)
    parser.add_argument('--save-name', type=str, required=True)
    parser.add_argument('--motion', type=int, default=2)
    parser.add_argument('--ip-scale', type=float, default=0.3)
    parser.add_argument('--strength', type=float, default=1)
    args = parser.parse_args()

    # prepare paths and pipeline
    if args.config:
        config = OmegaConf.load(args.config)
        print('Load DreamBooth, LoRA and other things from config:')
        print(config)
    else:
        # An empty dict makes every config.get(...) below return its default.
        config = dict()

    base_model_path = BASE_MODEL
    unet_path = I2V_MODEL
    dreambooth_path = config.get('dreambooth', DREAMBOOTH_PATH)
    vae_path = config.get('vae', None)
    lora_path = config.get('lora', None)
    lora_alpha = config.get('lora_alpha', 0)
    only_load_vae_decoder = config.get('only_load_vae_decoder', False)
    only_load_vae_encoder = config.get('only_load_vae_encoder', False)
    st_motion = config.get('st_motion', None)

    base_cfg = OmegaConf.load(BASE_CFG)
    validation_pipeline = I2VPipeline.build_pipeline(
        base_cfg,
        base_model_path,
        unet_path,
        dreambooth_path,
        lora_path,
        lora_alpha,
        vae_path,
        ip_adapter_path='./models/IP_Adapter/',
        ip_adapter_scale=args.ip_scale,
        only_load_vae_decoder=only_load_vae_decoder,
        only_load_vae_encoder=only_load_vae_encoder)
    print(f'using unet : {unet_path}')
    print(f'using DreamBooth: {dreambooth_path}')
    print(f'using Lora : {lora_path}')

    validation_pipeline.set_st_motion(st_motion)
    print(f'Set Style Transfer Motion: {validation_pipeline.st_motion}.')

    # load image
    image_in, height, width = preprocess_img(args.img)

    # Prepend the optional config suffix to the user prompt. When --prompt
    # was omitted, use the suffix on its own instead of failing on
    # str + None.
    suffix = config.get('suffix', None)
    if suffix and args.prompt is not None:
        prompt = suffix + ',' + args.prompt
    elif suffix:
        prompt = suffix
    else:
        prompt = args.prompt

    sample = validation_pipeline(
        image=image_in,
        prompt=prompt,
        height=height,
        width=width,
        video_length=16,
        num_inference_steps=25,
        mask_sim_template_idx=args.motion,
        negative_prompt=config.get('n_prompt', N_PROMPT),
        guidance_scale=config.get('guidance_scale', GUIDANCE_SCALE),
        ip_adapter_scale=args.ip_scale,
        strength=args.strength
    ).videos

    save_name = args.save_name
    parent_name = os.path.dirname(save_name)
    # A bare file name has no parent directory to create.
    if parent_name:
        os.makedirs(parent_name, exist_ok=True)
    imageio.mimsave(save_name, post_process(sample))
    print(" <<< Test Done <<<")


if __name__ == '__main__':
    main()