# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import gc import logging import math import os import random import sys import types from contextlib import contextmanager from copy import deepcopy from functools import partial import numpy as np import torch import torch.cuda.amp as amp import torch.distributed as dist import torchvision.transforms.functional as TF from decord import VideoReader from PIL import Image from safetensors import safe_open from torchvision import transforms from tqdm import tqdm from .distributed.fsdp import shard_model from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward from .distributed.util import get_world_size from .modules.s2v.audio_encoder import AudioEncoder from .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v from .modules.t5 import T5EncoderModel from .modules.vae2_1 import Wan2_1_VAE from .utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler def load_safetensors(path): tensors = {} with safe_open(path, framework="pt", device="cpu") as f: for key in f.keys(): tensors[key] = f.get_tensor(key) return tensors class WanS2V: def __init__( self, config, checkpoint_dir, device_id=0, rank=0, t5_fsdp=False, dit_fsdp=False, use_sp=False, t5_cpu=False, init_on_cpu=True, convert_model_dtype=False, ): r""" Initializes the image-to-video generation model components. Args: config (EasyDict): Object containing model parameters initialized from config.py checkpoint_dir (`str`): Path to directory containing model checkpoints device_id (`int`, *optional*, defaults to 0): Id of target GPU device rank (`int`, *optional*, defaults to 0): Process rank for distributed training t5_fsdp (`bool`, *optional*, defaults to False): Enable FSDP sharding for T5 model dit_fsdp (`bool`, *optional*, defaults to False): Enable FSDP sharding for DiT model use_sp (`bool`, *optional*, defaults to False): Enable distribution strategy of sequence parallel. t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. convert_model_dtype (`bool`, *optional*, defaults to False): Convert DiT model parameters dtype to 'config.param_dtype'. Only works without FSDP. """ self.device = torch.device(f"cuda:{device_id}") self.config = config self.rank = rank self.t5_cpu = t5_cpu self.init_on_cpu = init_on_cpu self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype if t5_fsdp or dit_fsdp or use_sp: self.init_on_cpu = False shard_fn = partial(shard_model, device_id=device_id) self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, device=torch.device('cpu'), checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), shard_fn=shard_fn if t5_fsdp else None, ) self.vae = Wan2_1_VAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) logging.info(f"Creating WanModel from {checkpoint_dir}") if not dit_fsdp: self.noise_model = WanModel_S2V.from_pretrained( checkpoint_dir, torch_dtype=self.param_dtype, device_map=self.device) else: self.noise_model = WanModel_S2V.from_pretrained( checkpoint_dir, torch_dtype=self.param_dtype) self.noise_model = self._configure_model( model=self.noise_model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, convert_model_dtype=convert_model_dtype) self.audio_encoder = AudioEncoder( model_id=os.path.join(checkpoint_dir, "wav2vec2-large-xlsr-53-english")) if use_sp: self.sp_size = get_world_size() else: self.sp_size = 1 self.sample_neg_prompt = config.sample_neg_prompt self.motion_frames = config.transformer.motion_frames self.drop_first_motion = config.drop_first_motion self.fps = config.sample_fps self.audio_sample_m = 0 def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype): """ Configures a model object. This includes setting evaluation modes, applying distributed parallel strategy, and handling device placement. Args: model (torch.nn.Module): The model instance to configure. use_sp (`bool`): Enable distribution strategy of sequence parallel. dit_fsdp (`bool`): Enable FSDP sharding for DiT model. shard_fn (callable): The function to apply FSDP sharding. convert_model_dtype (`bool`): Convert DiT model parameters dtype to 'config.param_dtype'. Only works without FSDP. Returns: torch.nn.Module: The configured model. """ model.eval().requires_grad_(False) if use_sp: for block in model.blocks: block.self_attn.forward = types.MethodType( sp_attn_forward_s2v, block.self_attn) model.use_context_parallel = True if dist.is_initialized(): dist.barrier() if dit_fsdp: model = shard_fn(model) else: if convert_model_dtype: model.to(self.param_dtype) if not self.init_on_cpu: model.to(self.device) return model def get_size_less_than_area(self, height, width, target_area=1024 * 704, divisor=64): if height * width <= target_area: # If the original image area is already less than or equal to the target, # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. max_upper_area = target_area min_scale = 0.1 max_scale = 1.0 else: # Resize to fit within the target area and then pad to multiples of `divisor` max_upper_area = target_area # Maximum allowed total pixel count after padding d = divisor - 1 b = d * (height + width) a = height * width c = d**2 - max_upper_area # Calculate scale boundaries using quadratic equation min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / ( 2 * a) # Scale when maximum padding is applied max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area # Use binary search-like iteration to find this scale find_it = False for i in range(100): scale = max_scale - (max_scale - min_scale) * i / 100 new_height, new_width = int(height * scale), int(width * scale) # Pad to make dimensions divisible by 64 pad_height = (64 - new_height % 64) % 64 pad_width = (64 - new_width % 64) % 64 pad_top = pad_height // 2 pad_bottom = pad_height - pad_top pad_left = pad_width // 2 pad_right = pad_width - pad_left padded_height, padded_width = new_height + pad_height, new_width + pad_width if padded_height * padded_width <= max_upper_area: find_it = True break if find_it: return padded_height, padded_width else: # Fallback: calculate target dimensions based on aspect ratio and divisor alignment aspect_ratio = width / height target_width = int( (target_area * aspect_ratio)**0.5 // divisor * divisor) target_height = int( (target_area / aspect_ratio)**0.5 // divisor * divisor) # Ensure the result is not larger than the original resolution if target_width >= width or target_height >= height: target_width = int(width // divisor * divisor) target_height = int(height // divisor * divisor) return target_height, target_width def prepare_default_cond_input(self, map_shape=[3, 12, 64, 64], motion_frames=5, lat_motion_frames=2, enable_mano=False, enable_kp=False, enable_pose=False): default_value = [1.0, -1.0, -1.0] cond_enable = [enable_mano, enable_kp, enable_pose] cond = [] for d, c in zip(default_value, cond_enable): if c: map_value = torch.ones( map_shape, dtype=self.param_dtype, device=self.device) * d cond_lat = torch.cat([ map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1), map_value ], dim=2) cond_lat = torch.stack( self.vae.encode(cond_lat.to( self.param_dtype)))[:, :, lat_motion_frames:].to( self.param_dtype) cond.append(cond_lat) if len(cond) >= 1: cond = torch.cat(cond, dim=1) else: cond = None return cond def encode_audio(self, audio_path, infer_frames): z = self.audio_encoder.extract_audio_feat( audio_path, return_all_layers=True) audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m) audio_embed_bucket = audio_embed_bucket.to(self.device, self.param_dtype) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) elif len(audio_embed_bucket.shape) == 4: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) return audio_embed_bucket, num_repeat def read_last_n_frames(self, video_path, n_frames, target_fps=16, reverse=False): """ Read the last `n_frames` from a video at the specified frame rate. Parameters: video_path (str): Path to the video file. n_frames (int): Number of frames to read. target_fps (int, optional): Target sampling frame rate. Defaults to 16. reverse (bool, optional): Whether to read frames in reverse order. If True, reads the first `n_frames` instead of the last ones. Returns: np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames. """ vr = VideoReader(video_path) original_fps = vr.get_avg_fps() total_frames = len(vr) interval = max(1, round(original_fps / target_fps)) required_span = (n_frames - 1) * interval start_frame = max(0, total_frames - required_span - 1) if not reverse else 0 sampled_indices = [] for i in range(n_frames): indice = start_frame + i * interval if indice >= total_frames: break else: sampled_indices.append(indice) return vr.get_batch(sampled_indices).asnumpy() def load_pose_cond(self, pose_video, num_repeat, infer_frames, size): HEIGHT, WIDTH = size if not pose_video is None: pose_seq = self.read_last_n_frames( pose_video, n_frames=infer_frames * num_repeat, target_fps=self.fps, reverse=True) resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) tensor_trans = transforms.ToTensor() cond_tensor = torch.from_numpy(pose_seq) cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute( 1, 0, 2, 3).unsqueeze(0) padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2] cond_tensor = torch.cat([ cond_tensor, - torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH]) ], dim=2) cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2) else: cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])] COND = [] for r in range(len(cond_tensors)): cond = cond_tensors[r] cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) cond_lat = torch.stack( self.vae.encode( cond.to(dtype=self.param_dtype, device=self.device)))[:, :, 1:].cpu() # for mem save COND.append(cond_lat) return COND def get_gen_size(self, size, max_area, ref_image_path, pre_video_path): if not size is None: HEIGHT, WIDTH = size else: if pre_video_path: ref_image = self.read_last_n_frames( pre_video_path, n_frames=1)[0] else: ref_image = np.array(Image.open(ref_image_path).convert('RGB')) HEIGHT, WIDTH = ref_image.shape[:2] HEIGHT, WIDTH = self.get_size_less_than_area( HEIGHT, WIDTH, target_area=max_area) return (HEIGHT, WIDTH) def generate( self, input_prompt, ref_image_path, audio_path, enable_tts, tts_prompt_audio, tts_prompt_text, tts_text, num_repeat=1, pose_video=None, max_area=720 * 1280, infer_frames=80, shift=5.0, sample_solver='unipc', sampling_steps=40, guide_scale=5.0, n_prompt="", seed=-1, offload_model=True, init_first_frame=False, ): r""" Generates video frames from input image and text prompt using diffusion process. Args: input_prompt (`str`): Text prompt for content generation. ref_image_path ('str'): Input image path audio_path ('str'): Audio for video driven num_repeat ('int'): Number of clips to generate; will be automatically adjusted based on the audio length pose_video ('str'): If provided, uses a sequence of poses to drive the generated video max_area (`int`, *optional*, defaults to 720*1280): Maximum pixel area for latent space calculation. Controls video resolution scaling infer_frames (`int`, *optional*, defaults to 80): How many frames to generate per clips. The number should be 4n shift (`float`, *optional*, defaults to 5.0): Noise schedule shift parameter. Affects temporal dynamics [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. sample_solver (`str`, *optional*, defaults to 'unipc'): Solver used to sample the video. sampling_steps (`int`, *optional*, defaults to 40): Number of diffusion sampling steps. Higher values improve quality but slow generation guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): Classifier-free guidance scale. Controls prompt adherence vs. creativity. If tuple, the first guide_scale will be used for low noise model and the second guide_scale will be used for high noise model. n_prompt (`str`, *optional*, defaults to ""): Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` seed (`int`, *optional*, defaults to -1): Random seed for noise generation. If -1, use random seed offload_model (`bool`, *optional*, defaults to True): If True, offloads models to CPU during generation to save VRAM init_first_frame (`bool`, *optional*, defaults to False): Whether to use the reference image as the first frame (i.e., standard image-to-video generation) Returns: torch.Tensor: Generated video frames tensor. Dimensions: (C, N H, W) where: - C: Color channels (3 for RGB) - N: Number of frames (81) - H: Frame height (from max_area) - W: Frame width from max_area) """ # preprocess size = self.get_gen_size( size=None, max_area=max_area, ref_image_path=ref_image_path, pre_video_path=None) HEIGHT, WIDTH = size channel = 3 resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) tensor_trans = transforms.ToTensor() ref_image = None motion_latents = None if ref_image is None: ref_image = np.array(Image.open(ref_image_path).convert('RGB')) if motion_latents is None: motion_latents = torch.zeros( [1, channel, self.motion_frames, HEIGHT, WIDTH], dtype=self.param_dtype, device=self.device) # extract audio emb if enable_tts is True: audio_path = self.tts(tts_prompt_audio, tts_prompt_text, tts_text) audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames) if num_repeat is None or num_repeat > nr: num_repeat = nr lat_motion_frames = (self.motion_frames + 3) // 4 model_pic = crop_opreat(resize_opreat(Image.fromarray(ref_image))) ref_pixel_values = tensor_trans(model_pic) ref_pixel_values = ref_pixel_values.unsqueeze(1).unsqueeze( 0) * 2 - 1.0 # b c 1 h w ref_pixel_values = ref_pixel_values.to( dtype=self.vae.dtype, device=self.vae.device) ref_latents = torch.stack(self.vae.encode(ref_pixel_values)) # encode the motion latents videos_last_frames = motion_latents.detach() drop_first_motion = self.drop_first_motion if init_first_frame: drop_first_motion = False motion_latents[:, :, -6:] = ref_pixel_values motion_latents = torch.stack(self.vae.encode(motion_latents)) # get pose cond input if need COND = self.load_pose_cond( pose_video=pose_video, num_repeat=num_repeat, infer_frames=infer_frames, size=size) seed = seed if seed >= 0 else random.randint(0, sys.maxsize) if n_prompt == "": n_prompt = self.sample_neg_prompt # preprocess if not self.t5_cpu: self.text_encoder.model.to(self.device) context = self.text_encoder([input_prompt], self.device) context_null = self.text_encoder([n_prompt], self.device) if offload_model: self.text_encoder.model.cpu() else: context = self.text_encoder([input_prompt], torch.device('cpu')) context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] out = [] # evaluation mode with ( torch.amp.autocast('cuda', dtype=self.param_dtype), torch.no_grad(), ): for r in range(num_repeat): seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed + r) lat_target_frames = (infer_frames + 3 + self.motion_frames ) // 4 - lat_motion_frames target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8] noise = [ torch.randn( 16, target_shape[0], target_shape[1], target_shape[2], dtype=self.param_dtype, device=self.device, generator=seed_g) ] max_seq_len = np.prod(target_shape) // 4 if sample_solver == 'unipc': sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False) sample_scheduler.set_timesteps( sampling_steps, device=self.device, shift=shift) timesteps = sample_scheduler.timesteps elif sample_solver == 'dpm++': sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( sample_scheduler, device=self.device, sigmas=sampling_sigmas) else: raise NotImplementedError("Unsupported solver.") latents = deepcopy(noise) with torch.no_grad(): left_idx = r * infer_frames right_idx = r * infer_frames + infer_frames cond_latents = COND[r] if pose_video else COND[0] * 0 cond_latents = cond_latents.to( dtype=self.param_dtype, device=self.device) audio_input = audio_emb[..., left_idx:right_idx] input_motion_latents = motion_latents.clone() arg_c = { 'context': context[0:1], 'seq_len': max_seq_len, 'cond_states': cond_latents, "motion_latents": input_motion_latents, 'ref_latents': ref_latents, "audio_input": audio_input, "motion_frames": [self.motion_frames, lat_motion_frames], "drop_motion_frames": drop_first_motion and r == 0, } if guide_scale > 1: arg_null = { 'context': context_null[0:1], 'seq_len': max_seq_len, 'cond_states': cond_latents, "motion_latents": input_motion_latents, 'ref_latents': ref_latents, "audio_input": 0.0 * audio_input, "motion_frames": [ self.motion_frames, lat_motion_frames ], "drop_motion_frames": drop_first_motion and r == 0, } if offload_model or self.init_on_cpu: self.noise_model.to(self.device) torch.cuda.empty_cache() for i, t in enumerate(tqdm(timesteps)): latent_model_input = latents[0:1] timestep = [t] timestep = torch.stack(timestep).to(self.device) noise_pred_cond = self.noise_model( latent_model_input, t=timestep, **arg_c) if guide_scale > 1: noise_pred_uncond = self.noise_model( latent_model_input, t=timestep, **arg_null) noise_pred = [ u + guide_scale * (c - u) for c, u in zip(noise_pred_cond, noise_pred_uncond) ] else: noise_pred = noise_pred_cond temp_x0 = sample_scheduler.step( noise_pred[0].unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=seed_g)[0] latents[0] = temp_x0.squeeze(0) if offload_model: self.noise_model.cpu() torch.cuda.synchronize() torch.cuda.empty_cache() latents = torch.stack(latents) if not (drop_first_motion and r == 0): decode_latents = torch.cat([motion_latents, latents], dim=2) else: decode_latents = torch.cat([ref_latents, latents], dim=2) image = torch.stack(self.vae.decode(decode_latents)) image = image[:, :, -(infer_frames):] if (drop_first_motion and r == 0): image = image[:, :, 3:] overlap_frames_num = min(self.motion_frames, image.shape[2]) videos_last_frames = torch.cat([ videos_last_frames[:, :, overlap_frames_num:], image[:, :, -overlap_frames_num:] ], dim=2) videos_last_frames = videos_last_frames.to( dtype=motion_latents.dtype, device=motion_latents.device) motion_latents = torch.stack( self.vae.encode(videos_last_frames)) out.append(image.cpu()) videos = torch.cat(out, dim=2) del noise, latents del sample_scheduler if offload_model: gc.collect() torch.cuda.synchronize() if dist.is_initialized(): dist.barrier() return videos[0] if self.rank == 0 else None def tts(self, tts_prompt_audio, tts_prompt_text, tts_text): if not hasattr(self, 'cosyvoice'): self.load_tts() speech_list = [] from cosyvoice.utils.file_utils import load_wav import torchaudio prompt_speech_16k = load_wav(tts_prompt_audio, 16000) if tts_prompt_text is not None: for i in self.cosyvoice.inference_zero_shot(tts_text, tts_prompt_text, prompt_speech_16k): speech_list.append(i['tts_speech']) else: for i in self.cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k): speech_list.append(i['tts_speech']) torchaudio.save('tts.wav', torch.concat(speech_list, dim=1), self.cosyvoice.sample_rate) return 'tts.wav' def load_tts(self): if not os.path.exists('CosyVoice'): from wan.utils.utils import download_cosyvoice_repo download_cosyvoice_repo('CosyVoice') if not os.path.exists('CosyVoice2-0.5B'): from wan.utils.utils import download_cosyvoice_model download_cosyvoice_model('CosyVoice2-0.5B', 'CosyVoice2-0.5B') sys.path.append('CosyVoice') sys.path.append('CosyVoice/third_party/Matcha-TTS') from cosyvoice.cli.cosyvoice import CosyVoice2 self.cosyvoice = CosyVoice2('CosyVoice2-0.5B')