import spaces
from huggingface_hub import snapshot_download, hf_hub_download
import os
import subprocess
import importlib, site
from PIL import Image
import uuid
import shutil
import time
import cv2
from generate import generate, load_model
import json
import gradio as gr
import sys

BASE = os.path.dirname(os.path.abspath(__file__))
PREPROCESS_DIR = os.path.join(BASE, "wan", "modules", "animate", "preprocess")
sys.path.append(PREPROCESS_DIR)

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()


def sh(cmd):
    subprocess.check_call(cmd, shell=True)


try:
    print("Attempting to download and build sam2...")

    print("download sam")
    sam_dir = snapshot_download(repo_id="alexnasa/sam2")

    @spaces.GPU(duration=500)
    def install_sam():
        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
        sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")

    print("install sam")
    install_sam()

    # tell Python to re-scan site-packages now that the egg-link exists
    import importlib, site
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

    print("sam2 installed successfully.")

except Exception as e:
    raise gr.Error(f"sam2 installation failed: {e}")

import torch
from preprocess_data import run as run_preprocess
from preprocess_data import load_preprocess_models

print(f"Torch version: {torch.__version__}")

os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"

snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")

wan_animate = load_model(True)

rc_mapping = {
    "Video → Ref Image": False,
    "Video ← Ref Image": True,
}


def preprocess_video(input_video_path, duration, session_id=None):
    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    process_video_path = os.path.join(output_dir, 'input_video.mp4')
    convert_video_to_30fps_and_clip(input_video_path, process_video_path, duration_s=duration)

    return process_video_path


def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
    """
    Extracts the audio track from a video file and saves it as a WAV file.

    Args:
        video_path (str): Path to the input video file.
        output_wav_path (str): Path to save the extracted WAV file.
        sample_rate (int, optional): Output sample rate (e.g., 16000). If None, keep the original.
    """
    cmd = [
        'ffmpeg',
        '-i', video_path,         # Input video
        '-vn',                    # Disable video
        '-acodec', 'pcm_s16le',   # 16-bit PCM (WAV format)
        '-ac', '1',               # Mono channel (use '2' for stereo)
        '-y',                     # Overwrite output
        '-loglevel', 'error'      # Cleaner output
    ]

    # Only add the sample rate option if explicitly specified
    if sample_rate is not None:
        cmd.extend(['-ar', str(sample_rate)])

    cmd.append(output_wav_path)

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        return True
    except subprocess.CalledProcessError as e:
        return False


def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
    """
    Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.

    Args:
        video_path (str): Path to the silent video file.
        audio_path (str): Path to the WAV audio file.
        output_video_path (str): Path to save the output MP4 with audio.
""" cmd = [ 'ffmpeg', '-i', video_path, # Input video '-i', audio_path, # Input audio '-c:v', 'copy', # Copy video without re-encoding '-c:a', 'aac', # Encode audio as AAC (MP4-compatible) '-shortest', # Stop when the shortest stream ends '-y', # Overwrite output '-loglevel', 'error', output_video_path ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) except subprocess.CalledProcessError as e: raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}") def convert_video_to_30fps_and_clip(input_video_path, output_video_path, duration_s=2, target_fps=30): # Decide target box depending on orientation *as ffmpeg sees it*. # We'll just compute both and let expressions pick the right one. # If you truly want different targets by orientation, keep your is_portrait() and set these constants accordingly. # Build a crop expression that: # - never exceeds the input size # - keeps values even (required by yuv420p) # - stays centered crop_w_expr = "floor(min(in_w\,1280)/2)*2" crop_h_expr = "floor(min(in_h\,720)/2)*2" crop_x_expr = f"floor((in_w - {crop_w_expr})/2/2)*2" crop_y_expr = f"floor((in_h - {crop_h_expr})/2/2)*2" vf = ( f"crop={crop_w_expr}:{crop_h_expr}:{crop_x_expr}:{crop_y_expr}," f"fps={target_fps}" ) cmd = [ "ffmpeg", "-nostdin", "-hide_banner", "-y", "-i", input_video_path, "-t", str(duration_s), # Do crop and fps in one -vf so they see the same frame geometry "-vf", vf, # Make sure the output has even dims and a standard pix_fmt "-pix_fmt", "yuv420p", "-c:v", "libx264", "-c:a", "aac", output_video_path, ] try: subprocess.run(cmd, check=True, capture_output=True, text=True) except subprocess.CalledProcessError as e: raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}") def is_portrait(video_file): # Get video information cap = cv2.VideoCapture(video_file) if not cap.isOpened(): error_msg = "Cannot open video file" gr.Warning(error_msg) orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() return orig_width < orig_height def calculate_time_required(max_duration_s, rc_bool): frames_count = 30 * max_duration_s chunks = frames_count // 77 + 1 if rc_bool: pose2d_tracking_duration_s = 75 iteration_per_step_s = 12 else: pose2d_tracking_duration_s = 65 iteration_per_step_s = 12 time_required = pose2d_tracking_duration_s + iteration_per_step_s * 5 * chunks print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}') return time_required def update_time_required(max_duration_s, rc_str): rc_bool = rc_mapping[rc_str] duration_s = calculate_time_required(max_duration_s, rc_bool) duration_m = duration_s / 60 return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)") def get_duration(input_video, max_duration_s, edited_frame, rc_bool, session_id, progress): return calculate_time_required(max_duration_s, rc_bool) @spaces.GPU(duration=120) def _animate(input_video, max_duration_s, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),): if session_id is None: session_id = uuid.uuid4().hex output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) os.makedirs(output_dir, exist_ok=True) preprocess_dir = os.path.join(output_dir, "preprocess_dir") os.makedirs(preprocess_dir, exist_ok=True) output_video_path = os.path.join(output_dir, 'result.mp4') # --- Measure preprocess time --- start_preprocess = time.time() if 
@spaces.GPU(duration=get_duration)
def _animate(input_video, max_duration_s, edited_frame, rc_bool, session_id=None, progress=gr.Progress(track_tqdm=True)):

    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    os.makedirs(preprocess_dir, exist_ok=True)

    output_video_path = os.path.join(output_dir, 'result.mp4')

    # --- Measure preprocess time ---
    start_preprocess = time.time()

    if is_portrait(input_video):
        w = 480
        h = 832
    else:
        w = 832
        h = 480

    tag_string = "retarget_flag"
    if rc_bool:
        tag_string = "replace_flag"

    # sh("python ./wan/modules/animate/preprocess/preprocess_data.py "
    #    "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint "
    #    f"--video_path {input_video} "
    #    f"--refer_path {edited_frame} "
    #    f"--save_path {preprocess_dir} "
    #    f"--resolution_area {w} {h} --{tag_string} "
    #    )
    preprocess_model = load_preprocess_models()
    run_preprocess(preprocess_model, input_video, edited_frame, preprocess_dir, w, h, tag_string)

    preprocess_time = time.time() - start_preprocess
    print(f"Preprocess took {preprocess_time:.2f} seconds")

    # --- Measure generate time ---
    start_generate = time.time()

    generate(wan_animate, preprocess_dir, output_video_path, rc_bool)

    generate_time = time.time() - start_generate
    print(f"Generate took {generate_time:.2f} seconds")

    # --- Optional total time ---
    total_time = preprocess_time + generate_time
    print(f"Total time: {total_time:.2f} seconds")

    return output_video_path


def animate_scene(input_video, max_duration_s, edited_frame, rc_str, session_id=None, progress=gr.Progress(track_tqdm=True)):

    if not input_video:
        raise gr.Error("Please provide a video")

    if not edited_frame:
        raise gr.Error("Please provide an image")

    if session_id is None:
        session_id = uuid.uuid4().hex

    input_video = preprocess_video(input_video, max_duration_s, session_id)
    rc_bool = rc_mapping[rc_str]

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    input_audio_path = os.path.join(output_dir, 'input_audio.wav')
    audio_extracted = extract_audio_from_video_ffmpeg(input_video, input_audio_path)

    edited_frame_png = os.path.join(output_dir, 'edited_frame.png')
    edited_frame_img = Image.open(edited_frame)
    edited_frame_img.save(edited_frame_png)

    output_video_path = _animate(input_video, max_duration_s, edited_frame_png, rc_bool, session_id, progress)

    final_video_path = os.path.join(output_dir, 'final_result.mp4')

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    if rc_bool:
        mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
        face_video = os.path.join(preprocess_dir, 'src_face.mp4')
    else:
        mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        face_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    if audio_extracted:
        combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)
    else:
        final_video_path = output_video_path

    return final_video_path, pose_video, bg_video, mask_video, face_video


css = """
#col-container {
    margin: 0 auto;
    max-width: 1600px;
}
#step-column {
    padding: 20px;
    border-radius: 8px;
    box-shadow: var(--card-shadow);
    margin: 10px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
.button-gradient {
    background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
    border: none;
    padding: 14px 28px;
    font-size: 16px;
    font-weight: bold;
    color: white;
    border-radius: 10px;
    cursor: pointer;
    transition: 0.3s ease-in-out;
    animation: 2s linear 0s infinite normal none running gradientAnimation;
    box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
}
.toggle-container {
    display: inline-flex;
    background-color: #ffd6ff; /* light pink background */
    border-radius: 9999px;
    padding: 4px;
    position: relative;
    width: fit-content;
    font-family: sans-serif;
}
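/* Hide the native radio inputs; the styled labels act as the toggle buttons */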
.toggle-container input[type="radio"] {
    display: none;
}
.toggle-container label {
    position: relative;
    z-index: 2;
    flex: 1;
    text-align: center;
    font-weight: 700;
    color: #4b2ab5; /* dark purple text for unselected */
    padding: 6px 22px;
    border-radius: 9999px;
    cursor: pointer;
    transition: color 0.25s ease;
}

/* Moving highlight */
.toggle-highlight {
    position: absolute;
    top: 4px;
    left: 4px;
    width: calc(50% - 4px);
    height: calc(100% - 8px);
    background-color: #4b2ab5; /* dark purple background */
    border-radius: 9999px;
    transition: transform 0.25s ease;
    z-index: 1;
}

/* When "True" is checked */
#true:checked ~ label[for="true"] {
    color: #ffd6ff; /* light pink text */
}

/* When "False" is checked */
#false:checked ~ label[for="false"] {
    color: #ffd6ff; /* light pink text */
}

/* Move highlight to right side when False is checked */
#false:checked ~ .toggle-highlight {
    transform: translateX(100%);
}
"""


def start_session(request: gr.Request):
    return request.session_hash


def cleanup(request: gr.Request):
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
        shutil.rmtree(d1, ignore_errors=True)


with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:

    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            gr.HTML(
                """

Wan2.2-Animate-14B

[Model]

HF Space By:

""" ) with gr.Row(): with gr.Column(elem_id="step-column"): gr.HTML("""
1. Upload a Video
""") input_video = gr.Video(label="Input Video", height=512) max_duration_slider = gr.Slider(2, 8, 2, step=2, label="Max Duration", visible=False) gr.Examples( examples=[ [ "./examples/thunder.mp4", ], [ "./examples/okay.mp4", ], [ "./examples/tell.mp4", ], [ "./examples/truth.mp4", ], [ "./examples/life.mp4", ], ], inputs=[input_video], cache_examples=False, ) with gr.Column(elem_id="step-column"): gr.HTML("""
2. Upload a Ref Image
""") edited_frame = gr.Image(label="Ref Image", type="filepath", height=512) replace_character_string = gr.Radio( ["Video → Ref Image", "Video ← Ref Image"], value="Video ← Ref Image", show_label=False ) gr.Examples( examples=[ [ "./examples/ali.png", ], [ "./examples/amber.png", ], [ "./examples/ella.png", ], [ "./examples/sydney.png", ], ], inputs=[edited_frame], cache_examples=False, ) with gr.Column(elem_id="step-column"): gr.HTML("""
3. Wan Animate it!
""") output_video = gr.Video(label="Edited Video", height=512) time_required = gr.Text(value="⌚ Zero GPU Required: ~195.0s (3.2 mins)", show_label=False, visible=False) action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient") with gr.Accordion("Preprocessed Data", open=False, visible=True): with gr.Row(): pose_video = gr.Video(label="Pose Video") bg_video = gr.Video(label="Background Video") face_video = gr.Video(label="Face Video") mask_video = gr.Video(label="Mask Video") with gr.Row(): with gr.Column(elem_id="col-showcase"): gr.Examples( examples=[ [ "./examples/okay.mp4", 2, "./examples/amber.png", "Video ← Ref Image" ], [ "./examples/superman.mp4", 2, "./examples/superman.png", "Video ← Ref Image" ], [ "./examples/test_example.mp4", 2, "./examples/ella.png", "Video ← Ref Image" ], [ "./examples/paul.mp4", 2, "./examples/man.png", "Video → Ref Image" ], [ "./examples/desi.mp4", 2, "./examples/desi.png", "Video ← Ref Image" ], ], inputs=[input_video, max_duration_slider, edited_frame, replace_character_string], outputs=[output_video, pose_video, bg_video, mask_video, face_video], fn=animate_scene, cache_examples=True, ) action_button.click(fn=animate_scene, inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video]) max_duration_slider.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required]) replace_character_string.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required]) if __name__ == "__main__": demo.queue() demo.unload(cleanup) demo.launch(ssr_mode=False, share=True)