import math
import os
import sys
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np
import torch
from diffusers import FluxPipeline
from PIL import Image
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
                               seed_everything)
from pytorch_lightning.cli import LightningArgumentParser, LightningCLI
from safetensors.torch import load_file as load_safetensors

from dataloader.dataset_factory import SingleImageDatasetFactory
from dataloader.video_data_module import VideoDataModule
from diffusion_trainer import streaming_svd as streaming_svd_model
from i2v_enhance import i2v_enhance_interface
from lib.farancia import IImage
from utils.aux import ensure_annotation_class
from utils.loader import download_ckpt


class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_argument("--image", type=Path,
                            help="Path to the input image(s).")
        parser.add_argument("--output", type=Path,
                            help="Path to the output folder.")
        parser.add_argument("--num_frames", type=int, default=100,
                            help="Number of frames to generate.")
        parser.add_argument("--out_fps", type=int, default=24,
                            help="Framerate of the generated video.")
        parser.add_argument("--chunk_size", type=int, default=38,
                            help="Chunk size used in randomized blending.")
        parser.add_argument("--overlap_size", type=int, default=12,
                            help="Overlap size used in randomized blending.")
        parser.add_argument("--use_randomized_blending", action="store_true",
                            help="Whether to use randomized blending.")
        parser.add_argument("--use_fp16", action="store_true",
                            help="Whether to use float16 quantization.")
        parser.add_argument("--prompt", type=str, default="",
                            help="Text prompt for text-to-video generation.")
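
# Example CLI invocation (a sketch; the script filename and asset paths are
# assumptions, and config.yaml is expected to sit next to this file):
#
#   python streaming_svd.py --image assets/example.png --output results \
#       --num_frames 200 --use_randomized_blending --chunk_size 38 --overlap_size 12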


class StreamingSVD():
    def __init__(self, load_argv=True) -> None:
        call_fol = Path(os.getcwd()).resolve()
        code_fol = Path(__file__).resolve().parent
        code_fol = os.path.relpath(code_fol, call_fol)
        # Keep a copy of argv: LightningCLI parses sys.argv directly, so we
        # temporarily rewrite it and restore it once the CLI has been built.
        argv_backup = deepcopy(sys.argv)
        if "--use_fp16" in sys.argv:
            os.environ["STREAMING_USE_FP16"] = "True"
        sys.argv = [__file__]
        sys.argv.extend(self.__config_call(
            argv_backup[1:] if load_argv else [], code_fol))
        cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True,
                        parser_kwargs={"parser_mode": "omegaconf"},
                        save_config_callback=None)
        self.__init_models(cli)
        self.__init_fields(cli)
        sys.argv = argv_backup

    def __init_models(self, cli):
        model = cli.model
        trainer = cli.trainer
        # Load the trained StreamingSVD checkpoint (either safetensors or a
        # regular PyTorch checkpoint with a "state_dict" entry).
        path = download_ckpt(
            local_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_local,
            global_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_global,
        )
        if path.endswith(".safetensors"):
            ckpt = load_safetensors(path)
        else:
            ckpt = torch.load(path, map_location="cpu")["state_dict"]
        model.load_state_dict(ckpt)  # load trained model
        data_module_loader = partial(VideoDataModule, workers=0)
        # Video frame interpolation (VFI) and enhancement models.
        vfi = i2v_enhance_interface.vfi_init(model.vfi)
        enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init(
            model.i2v_enhance)
        enhance_pipeline.unet.enable_forward_chunking(chunk_size=1, dim=1)
        # FLUX text-to-image model, used as the first stage of text-to-video.
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        flux_pipe.enable_model_cpu_offload()
        # Store the handles on the instance.
        self.model = model
        self.vfi = vfi
        self.data_module_loader = data_module_loader
        self.enhance_pipeline = enhance_pipeline
        self.enhance_generator = enhance_generator
        self.trainer = trainer
        self.flux_pipe = flux_pipe
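
    # Note: the interfaces below chain four models: FLUX (text-to-image),
    # the autoregressive StreamingSVD generator (image-to-video), the
    # enhancement pipeline, and the VFI model for frame interpolation.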

    def __init_fields(self, cli):
        self.input_path = cli.config["image"]
        self.output_path = cli.config["output"]
        self.num_frames = cli.config["num_frames"]
        self.fps = cli.config["out_fps"]
        self.use_randomized_blending = cli.config["use_randomized_blending"]
        self.chunk_size = cli.config["chunk_size"]
        self.overlap_size = cli.config["overlap_size"]
        self.prompt = cli.config["prompt"]

    def __config_call(self, config_cmds, code_fol):
        cmds = [cmd for cmd in config_cmds if len(cmd) > 0]
        cmd_init = ["--config", f"{code_fol}/config.yaml"]
        if "--use_fp16" in config_cmds:
            cmd_init.append("--trainer.precision=16-true")
        cmd_init.extend(cmds)
        return cmd_init
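
    # For illustration (the argument values are assumptions, not part of the
    # original code): called with config_cmds = ["--use_fp16", "--num_frames=200"],
    # __config_call returns
    #   ["--config", "<code_fol>/config.yaml", "--trainer.precision=16-true",
    #    "--use_fp16", "--num_frames=200"],
    # i.e. user arguments come after the config file so they override it.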

    # interfaces
    def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False,
                      chunk_size: int = 38, overlap_size: int = 12, seed=33):
        image = self.text_to_image(prompt=prompt)
        return self.streaming_i2v(image, num_frames=num_frames,
                                  use_randomized_blending=use_randomized_blending,
                                  chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)

    def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False,
                      chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.ndarray:
        # Generate half the frames; the VFI stage doubles the count later.
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            image, num_frames=(num_frames + 1) // 2, seed=seed)
        max_memory_allocated = torch.cuda.max_memory_allocated()
        print(f"max_memory_allocated at image_to_video: {max_memory_allocated}")
        video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video,
                                   chunk_size=chunk_size, overlap_size=overlap_size,
                                   use_randomized_blending=use_randomized_blending, seed=seed)
        video = self.interpolate_video(video, dest_num_frames=num_frames)
        # Scale/crop back to the input size.
        if image.shape[0] == 1:
            image = image[0]
        video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop(
            (0, 0, image.shape[1], image.shape[0])).numpy()
        max_memory_allocated = torch.cuda.max_memory_allocated()
        print(f"max_memory_allocated at interpolate_video: {max_memory_allocated}")
        return video
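
    # Minimal programmatic usage (a sketch; assumes a CUDA device, that the
    # checkpoints referenced in config.yaml are available, and that `frame`
    # is a uint8 numpy array of shape [H, W, 3]):
    #
    #   svd = StreamingSVD(load_argv=False)
    #   video = svd.streaming_i2v(frame, num_frames=100)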

    # StreamingSVD pipeline
    def streaming(self, image: np.ndarray):
        datamodule = self.data_module_loader(
            predict_dataset_factory=SingleImageDatasetFactory(file=image))
        self.trainer.predict(model=self.model, datamodule=datamodule)
        # The predict loop stores its results on the trainer.
        video = self.trainer.generated_video
        expanded_size = self.trainer.expanded_size
        scaled_outpainted_image = self.trainer.scaled_outpainted_image
        return video, scaled_outpainted_image, expanded_size

    def image_to_video(self, image: Union[np.ndarray, str], num_frames: int,
                       seed=33) -> tuple[np.ndarray, Image.Image, list[int]]:
        seed_everything(seed)
        if isinstance(image, str):
            image = IImage.open(image).numpy()
        if image.ndim == 4 and image.shape[0] == 1:
            image = image[0]
        assert image.shape[-1] == 3 and image.shape[0] > 1, \
            "Wrong image format. Expecting shape [H, W, C] with C = 3."
        assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8."
        # Compute the number of autoregressive chunks needed: the first pass
        # produces n_frames_per_gen frames, and every further pass adds
        # (n_frames_per_gen - n_cond_frames) new frames.
        n_cond_frames = self.model.inference_params.num_conditional_frames
        n_frames_per_gen = self.model.sampler.guider.num_frames
        n_autoregressive_generations = math.ceil(
            (num_frames - n_frames_per_gen) / (n_frames_per_gen - n_cond_frames))
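        # Worked example (the concrete values are assumptions, not read from
        # the config): with n_frames_per_gen = 25 and n_cond_frames = 8,
        # requesting num_frames = 100 gives ceil((100 - 25) / (25 - 8)) = 5 passes.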
        self.model.inference_params.n_autoregressive_generations = int(
            n_autoregressive_generations)
        print(" --- STREAMING ----- [START]")
        video, scaled_outpainted_image, expanded_size = self.streaming(image=image)
        print(f" --- STREAMING ----- [FINISHED]: {video.shape}")
        video = video[:num_frames]
        return video, scaled_outpainted_image, expanded_size

    def enhance_video(self, video: Union[np.ndarray, str], image: np.ndarray = None,
                      chunk_size=38, overlap_size=12, strength=0.97,
                      use_randomized_blending=False, seed=33, num_frames=None):
        seed_everything(seed)
        if isinstance(video, str):
            video = IImage.open(video).numpy()
        if image is None:
            image = video[0]
            print("ATTENTION: Using the first frame of the previous stage as the input frame for enhancement.")
        if num_frames is not None:
            video = video[:num_frames, ...]
        if not use_randomized_blending:
            # Without randomized blending, enhance the whole video as one chunk.
            chunk_size = video.shape[0]
            overlap_size = 0
        if image.ndim == 3:
            image = image[None]
        image = [Image.fromarray(
            IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])]
        video = np.split(video, video.shape[0])
        video = [Image.fromarray(frame[0]).resize((1280, 720)) for frame in video]
        print(f"---- ENHANCE ---- [START]. Video length = {len(video)}. "
              f"Randomized Blending = {use_randomized_blending}. "
              f"Chunk size = {chunk_size}. Overlap size = {overlap_size}.")
        video_enhanced = i2v_enhance_interface.i2v_enhance_process(
            image=image, video=video, pipeline=self.enhance_pipeline,
            generator=self.enhance_generator, chunk_size=chunk_size,
            overlap_size=overlap_size, strength=strength,
            use_randomized_blending=use_randomized_blending)
        video_enhanced = np.stack([np.asarray(frame) for frame in video_enhanced], axis=0)
        print("---- ENHANCE ---- [FINISHED].")
        return video_enhanced
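
    # On the chunking parameters (an informal note, not extra code): the video
    # is enhanced in chunks of chunk_size frames that share overlap_size frames
    # with their neighbours, and the overlapping regions are blended to hide
    # seams; with randomized blending the chunk positions are additionally
    # varied, as the name suggests, to further reduce seam artifacts.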

    def interpolate_video(self, video: np.ndarray, dest_num_frames: int):
        video = np.split(video, len(video))
        video = [frame[0] for frame in video]
        print(" ---- VFI ---- [START]")
        self.vfi.device()
        video_vfi = i2v_enhance_interface.vfi_process(
            video=video, vfi=self.vfi, video_len=dest_num_frames)
        video_vfi = np.stack([np.asarray(frame) for frame in video_vfi], axis=0)
        self.vfi.unload()
        print(f"---- VFI ---- [FINISHED]. Video length = {len(video_vfi)}")
        return video_vfi
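
    # Note: streaming_i2v generates (num_frames + 1) // 2 frames and relies on
    # this VFI stage to reach the requested num_frames, e.g. 50 generated
    # frames are interpolated up to a 100-frame output.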

    # T2I method
    def text_to_image(self, prompt, seed=33):
        # FLUX
        print("[FLUX] Generating image from text prompt")
        out = self.flux_pipe(
            prompt=prompt,
            guidance_scale=0,
            height=720,
            width=1280,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator(device=self.model.device).manual_seed(seed),
        ).images[0]
        print("[FLUX] Finished")
        return np.array(out)
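
    # FLUX.1-schnell is a timestep-distilled model, which is why it is run
    # here with guidance_scale=0 and only 4 inference steps.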


if __name__ == "__main__":
    def get_input_data(input_path: Path):
        if input_path.is_file():
            inputs = [input_path]
        else:
            # Collect png, jpg, jpeg and bmp images (case-insensitive).
            suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]",
                        "*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"]
            inputs = []
            for suffix in suffixes:
                inputs.extend(list(input_path.glob(suffix)))
            assert len(inputs) > 0, "No images found. Please make sure the input path is correct."
        img_as_np = [IImage.open(img_file).numpy() for img_file in inputs]
        return zip(img_as_np, inputs)

    streaming_svd = StreamingSVD()
    num_frames = streaming_svd.num_frames
    chunk_size = streaming_svd.chunk_size
    overlap_size = streaming_svd.overlap_size
    use_randomized_blending = streaming_svd.use_randomized_blending
    if not use_randomized_blending:
        # Without randomized blending, the enhance stage processes all
        # generated frames (half of num_frames before VFI) as a single chunk.
        chunk_size = (num_frames + 1) // 2
        overlap_size = 0
    result_path = Path(streaming_svd.output_path)
    seed = 33
    assert not result_path.exists() or result_path.is_dir(), \
        "Output path must be the path to a folder."
    prompt = streaming_svd.prompt
    if len(prompt) == 0:
        # Image-to-video: run the pipeline on every image found at the input path.
        for img, img_path in get_input_data(streaming_svd.input_path):
            video = streaming_svd.streaming_i2v(
                image=img, num_frames=num_frames,
                use_randomized_blending=use_randomized_blending,
                chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
            if not result_path.exists():
                result_path.mkdir(parents=True)
            result_file = (result_path / (img_path.stem + ".mp4")).as_posix()
            IImage(video, vmin=0, vmax=255).setFps(streaming_svd.fps).save(result_file)
            print(f"Video created at: {result_file}")
    else:
        # Text-to-video: generate an image with FLUX first, then animate it.
        video = streaming_svd.streaming_t2v(
            prompt=prompt, num_frames=num_frames,
            use_randomized_blending=use_randomized_blending,
            chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
        # Derive a filesystem-safe filename from the prompt, truncated to 15 chars.
        prompt_file = prompt.replace(" ", "_").replace(
            ".", "_").replace("/", "_").replace(":", "_")
        prompt_file = prompt_file[:15]
        if not result_path.exists():
            result_path.mkdir(parents=True)
        result_file = (result_path / (prompt_file + ".mp4")).as_posix()
        IImage(video, vmin=0, vmax=255).setFps(streaming_svd.fps).save(result_file)
        print(f"Video created at: {result_file}")