import spaces
from huggingface_hub import snapshot_download, hf_hub_download
import os
import subprocess
import importlib, site
from PIL import Image
import uuid
import shutil
import time
import cv2
from generate import generate, load_model
import json

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

def sh(cmd):
    subprocess.check_call(cmd, shell=True)

try:
    print("Attempting to download and build sam2...")
    print("download sam")
    sam_dir = snapshot_download(repo_id="alexnasa/sam2")

    def install_sam():
        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
        sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")
        print("install sam")

    install_sam()

    # tell Python to re-scan site-packages now that the egg-link exists
    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()

    flash_attention_installed = True
    print("sam2 installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install sam2: {e}")
    print("Continuing without sam2...")

import torch
print(f"Torch version: {torch.__version__}")

os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"

import gradio as gr

snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")

wan_animate = load_model(True)

# NOTE: the arrow directions below are assumed; the two options must be
# distinct strings for the mapping to work.
rc_mapping = {
    "Video → Ref Image": False,
    "Video ← Ref Image": True,
}
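
# rc_bool selects between the two preprocessing/generation paths further down:
#   False -> preprocess_data.py runs with --retarget_flag (drive the reference
#            image with the motion extracted from the video)
#   True  -> preprocess_data.py runs with --replace_flag (swap the character in
#            the original video for the reference image)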

def preprocess_video(input_video_path, session_id=None):
    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    process_video_path = os.path.join(output_dir, 'input_video.mp4')
    convert_video_to_30fps_and_clip(input_video_path, process_video_path, crop_width=720, crop_height=1280)

    return process_video_path
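
# preprocess_video is wired to the video-upload event below: the uploaded clip is
# re-encoded to 30 fps, trimmed to the first 2 seconds, and centre-cropped to at
# most 720x1280 (see convert_video_to_30fps_and_clip) before any GPU work starts.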

def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
    """
    Extracts the audio track from a video file and saves it as a WAV file.

    Args:
        video_path (str): Path to the input video file.
        output_wav_path (str): Path to save the extracted WAV file.
        sample_rate (int, optional): Output sample rate (e.g., 16000).
            If None, keep the original.
    """
    cmd = [
        'ffmpeg',
        '-i', video_path,          # Input video
        '-vn',                     # Disable video
        '-acodec', 'pcm_s16le',    # 16-bit PCM (WAV format)
        '-ac', '1',                # Mono channel (use '2' for stereo)
        '-y',                      # Overwrite output
        '-loglevel', 'error'       # Cleaner output
    ]

    # Only add the sample rate option if explicitly specified
    if sample_rate is not None:
        cmd.extend(['-ar', str(sample_rate)])

    cmd.append(output_wav_path)

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
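
# Example usage (hypothetical paths), e.g. for a 16 kHz mono WAV:
#   extract_audio_from_video_ffmpeg("clip.mp4", "clip_audio.wav", sample_rate=16000)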

def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
    """
    Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.

    Args:
        video_path (str): Path to the silent video file.
        audio_path (str): Path to the WAV audio file.
        output_video_path (str): Path to save the output MP4 with audio.
    """
    cmd = [
        'ffmpeg',
        '-i', video_path,      # Input video
        '-i', audio_path,      # Input audio
        '-c:v', 'copy',        # Copy video without re-encoding
        '-c:a', 'aac',         # Encode audio as AAC (MP4-compatible)
        '-shortest',           # Stop when the shortest stream ends
        '-y',                  # Overwrite output
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
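
# Example usage (hypothetical paths); '-shortest' keeps the output no longer than
# the shorter of the two input streams:
#   combine_video_and_audio_ffmpeg("result.mp4", "input_audio.wav", "final_result.mp4")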

def convert_video_to_30fps_and_clip(
    input_video_path,
    output_video_path,
    duration_s=2,
    target_fps=30,
    crop_width=None,
    crop_height=None
):
    # Get input video dimensions using ffprobe
    if crop_width and crop_height:
        probe_cmd = [
            'ffprobe', '-v', 'error', '-select_streams', 'v:0',
            '-show_entries', 'stream=width,height',
            '-of', 'json', input_video_path
        ]
        probe_result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
        video_info = json.loads(probe_result.stdout)
        w = video_info['streams'][0]['width']
        h = video_info['streams'][0]['height']

        # Clamp crop size to not exceed actual dimensions
        crop_width = min(crop_width, w)
        crop_height = min(crop_height, h)

        # Center crop offsets
        crop_x = max((w - crop_width) // 2, 0)
        crop_y = max((h - crop_height) // 2, 0)

        crop_filter = f"crop={crop_width}:{crop_height}:{crop_x}:{crop_y}"
    else:
        crop_filter = None

    cmd = [
        'ffmpeg',
        '-i', input_video_path,
        '-r', str(target_fps),
        '-t', str(duration_s),
    ]

    if crop_filter:
        cmd += ['-vf', crop_filter]

    cmd += [
        '-c:v', 'libx264',
        '-c:a', 'aac',
        '-strict', 'experimental',
        '-y',
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
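
# Example usage (hypothetical path), matching how preprocess_video calls it:
#   convert_video_to_30fps_and_clip("upload.mp4", "input_video.mp4",
#                                   crop_width=720, crop_height=1280)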

def get_frames_count(video_file):
    # Get video information
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        error_msg = "Cannot open video file"
        gr.Warning(error_msg)
        return 0  # bail out instead of querying an unopened capture

    orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    return orig_frame_count

def calculate_time_required(input_video, rc_bool):
    frames_count = get_frames_count(input_video)
    chunks = frames_count // 77 + 1

    if rc_bool:
        pose2d_tracking_duration_s = 75
        iteration_per_step_s = 13
    else:
        pose2d_tracking_duration_s = 50
        iteration_per_step_s = 12

    time_required = pose2d_tracking_duration_s + iteration_per_step_s * 20 * chunks

    print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}')

    return time_required
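
# Worked example of the estimate above: a 2-second, 30 fps clip has 60 frames,
# giving 60 // 77 + 1 = 1 chunk, so roughly 50 + 12 * 20 = 290 s in retarget mode
# and 75 + 13 * 20 = 335 s in replace mode (the factor of 20 presumably being the
# number of sampling steps per chunk).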

def update_time_required(input_video, rc_str):
    if input_video is None:
        return gr.update(value="⏱ Zero GPU Required: --")

    rc_bool = rc_mapping[rc_str]

    duration_s = calculate_time_required(input_video, rc_bool)
    duration_m = duration_s / 60

    return gr.update(value=f"⏱ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")

def get_duration(input_video, edited_frame, rc_bool, session_id, progress):
    return calculate_time_required(input_video, rc_bool)
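
# get_duration mirrors _animate's signature so it can serve as a per-request
# duration estimate. The ZeroGPU decorator below is assumed from the otherwise
# unused `import spaces` and `get_duration`; it requests a GPU slot sized to the
# estimated runtime.
@spaces.GPU(duration=get_duration)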
def _animate(input_video, edited_frame, rc_bool, session_id=None, progress=gr.Progress(track_tqdm=True)):
    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    os.makedirs(preprocess_dir, exist_ok=True)

    output_video_path = os.path.join(output_dir, 'result.mp4')

    # --- Measure preprocess time ---
    start_preprocess = time.time()

    # w = 720
    # h = 480

    # w = 720
    # h = 1280

    w = 480
    h = 832

    # w = 480
    # h = 720

    tag_string = "retarget_flag"
    if rc_bool:
        tag_string = "replace_flag"

    sh("python ./wan/modules/animate/preprocess/preprocess_data.py "
       "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint "
       f"--video_path {input_video} "
       f"--refer_path {edited_frame} "
       f"--save_path {preprocess_dir} "
       f"--resolution_area {w} {h} --{tag_string} "
       )

    preprocess_time = time.time() - start_preprocess
    print(f"Preprocess took {preprocess_time:.2f} seconds")

    # --- Measure generate time ---
    start_generate = time.time()

    generate(wan_animate, preprocess_dir, output_video_path, rc_bool)

    generate_time = time.time() - start_generate
    print(f"Generate took {generate_time:.2f} seconds")

    # --- Optional total time ---
    total_time = preprocess_time + generate_time
    print(f"Total time: {total_time:.2f} seconds")

    return output_video_path

def animate_scene(input_video, edited_frame, rc_str, session_id=None, progress=gr.Progress(track_tqdm=True)):
    if not input_video:
        raise gr.Error("Please provide a video")
    if not edited_frame:
        raise gr.Error("Please provide an image")

    if session_id is None:
        session_id = uuid.uuid4().hex

    rc_bool = rc_mapping[rc_str]

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    input_audio_path = os.path.join(output_dir, 'input_audio.wav')
    extract_audio_from_video_ffmpeg(input_video, input_audio_path)

    output_video_path = _animate(input_video, edited_frame, rc_bool, session_id, progress)

    final_video_path = os.path.join(output_dir, 'final_result.mp4')

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    if rc_bool:
        mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
        face_video = os.path.join(preprocess_dir, 'src_face.mp4')
    else:
        # Retarget mode does not use the mask/background/face clips here, so the
        # pose video is reused for those preview slots.
        mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        face_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)

    return final_video_path, pose_video, bg_video, mask_video, face_video
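
# The returned tuple matches the order of the Gradio outputs wired up below:
# [output_video, pose_video, bg_video, mask_video, face_video].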
| css = """ | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 1600px; | |
| } | |
| #step-column { | |
| padding: 20px; | |
| border-radius: 8px; | |
| box-shadow: var(--card-shadow); | |
| margin: 10px; | |
| } | |
| #col-showcase { | |
| margin: 0 auto; | |
| max-width: 1100px; | |
| } | |
| .button-gradient { | |
| background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%; | |
| border: none; | |
| padding: 14px 28px; | |
| font-size: 16px; | |
| font-weight: bold; | |
| color: white; | |
| border-radius: 10px; | |
| cursor: pointer; | |
| transition: 0.3s ease-in-out; | |
| animation: 2s linear 0s infinite normal none running gradientAnimation; | |
| box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px; | |
| } | |
| .toggle-container { | |
| display: inline-flex; | |
| background-color: #ffd6ff; /* light pink background */ | |
| border-radius: 9999px; | |
| padding: 4px; | |
| position: relative; | |
| width: fit-content; | |
| font-family: sans-serif; | |
| } | |
| .toggle-container input[type="radio"] { | |
| display: none; | |
| } | |
| .toggle-container label { | |
| position: relative; | |
| z-index: 2; | |
| flex: 1; | |
| text-align: center; | |
| font-weight: 700; | |
| color: #4b2ab5; /* dark purple text for unselected */ | |
| padding: 6px 22px; | |
| border-radius: 9999px; | |
| cursor: pointer; | |
| transition: color 0.25s ease; | |
| } | |
| /* Moving highlight */ | |
| .toggle-highlight { | |
| position: absolute; | |
| top: 4px; | |
| left: 4px; | |
| width: calc(50% - 4px); | |
| height: calc(100% - 8px); | |
| background-color: #4b2ab5; /* dark purple background */ | |
| border-radius: 9999px; | |
| transition: transform 0.25s ease; | |
| z-index: 1; | |
| } | |
| /* When "True" is checked */ | |
| #true:checked ~ label[for="true"] { | |
| color: #ffd6ff; /* light pink text */ | |
| } | |
| /* When "False" is checked */ | |
| #false:checked ~ label[for="false"] { | |
| color: #ffd6ff; /* light pink text */ | |
| } | |
| /* Move highlight to right side when False is checked */ | |
| #false:checked ~ .toggle-highlight { | |
| transform: translateX(100%); | |
| } | |
| """ | |

def start_session(request: gr.Request):
    return request.session_hash

def cleanup(request: gr.Request):
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
        shutil.rmtree(d1, ignore_errors=True)

with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: center;">
                    <p style="font-size:16px; display: inline; margin: 0;">
                        <strong>Wan2.2-Animate-14B </strong>
                    </p>
                    <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        [Model]
                    </a>
                </div>
                <div style="text-align: center;">
                    <p style="font-size:16px; display: inline; margin: 0;">
                        HF Space By:
                    </p>
                    <a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        <img src="https://img.shields.io/badge/🤗-Follow Me-yellow.svg">
                    </a>
                </div>
                """
            )

        with gr.Row():
            with gr.Column(elem_id="step-column"):
                gr.HTML("""
                    <div>
                        <span style="font-size: 24px;">1. Upload a Video</span><br>
                    </div>
                """)
                input_video = gr.Video(label="Input Video", height=512)

            with gr.Column(elem_id="step-column"):
                gr.HTML("""
                    <div>
                        <span style="font-size: 24px;">2. Upload a Ref Image</span><br>
                    </div>
                """)
                edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)

                gr.HTML("""
                    <div>
                        <span style="font-size: 24px;">3. Choose Mode</span><br>
                    </div>
                """)
                replace_character_string = gr.Radio(
                    ["Video → Ref Image", "Video ← Ref Image"],
                    value="Video → Ref Image",
                    show_label=False,
                )

            with gr.Column(elem_id="step-column"):
                gr.HTML("""
                    <div>
                        <span style="font-size: 24px;">4. Wan Animate it!</span><br>
                    </div>
                """)
                output_video = gr.Video(label="Edited Video", height=512)
                time_required = gr.Text(value="⏱ Zero GPU Required: --", show_label=False)
                action_button = gr.Button("Wan Animate", variant='primary', elem_classes="button-gradient")

                with gr.Accordion("Preprocessed Data", open=False, visible=False):
                    pose_video = gr.Video(label="Pose Video", height=512)
                    bg_video = gr.Video(label="Background Video", height=512)
                    face_video = gr.Video(label="Face Video", height=512)
                    mask_video = gr.Video(label="Mask Video", height=512)

        with gr.Row():
            with gr.Column(elem_id="col-showcase"):
                gr.Examples(
                    examples=[
                        [
                            "./examples/desi.mp4",
                            "./examples/desi.png",
                            "Video → Ref Image",
                        ],
                        [
                            "./examples/paul.mp4",
                            "./examples/man.png",
                            "Video ← Ref Image",
                        ],
                    ],
                    inputs=[input_video, edited_frame, replace_character_string],
                    outputs=[output_video, pose_video, bg_video, mask_video, face_video],
                    fn=animate_scene,
                    cache_examples=True,
                )

    action_button.click(
        fn=animate_scene,
        inputs=[input_video, edited_frame, replace_character_string, session_state],
        outputs=[output_video, pose_video, bg_video, mask_video, face_video],
    )

    input_video.upload(
        preprocess_video,
        inputs=[input_video, session_state],
        outputs=[input_video],
    ).then(
        update_time_required,
        inputs=[input_video, replace_character_string],
        outputs=[time_required],
    )

    replace_character_string.change(
        update_time_required,
        inputs=[input_video, replace_character_string],
        outputs=[time_required],
    )
| if __name__ == "__main__": | |
| demo.queue() | |
| demo.unload(cleanup) | |
| demo.launch(ssr_mode=False, share=True) | |