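"""Gradio demo for Lightricks LTX-Video: multi-scale text/image-to-video
generation using the 13B distilled FP8 checkpoint plus a latent spatial
upsampler. Models are downloaded from the Hugging Face Hub on startup,
built on CPU, and then moved to the inference device.
"""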
import gradio as gr
import torch
import numpy as np
import os
import yaml
from typing import Optional
from pathlib import Path
import imageio
import tempfile
from huggingface_hub import hf_hub_download
from inference import (
    create_ltx_video_pipeline,
    create_latent_upsampler,
    seed_everething,  # (sic) spelling matches the upstream inference module
    calculate_padding,
    load_media_file,
)
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
config_file_path = APP_HOME / "configs/ltxv-13b-0.9.8-distilled-fp8.yaml"
with open(config_file_path, "r") as file:
    PIPELINE_CONFIG_YAML = yaml.safe_load(file)
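# Keys this script reads explicitly from the YAML config:
#   checkpoint_path, spatial_upscaler_model_path, precision, sampler,
#   text_encoder_model_name_or_path, prompt_enhancer_image_caption_model_name_or_path,
#   prompt_enhancer_llm_model_name_or_path, and (optionally) max_resolution.
# Any remaining keys are forwarded unchanged to the pipeline call via **PIPELINE_CONFIG_YAML.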
HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
models_dir = Path("/data/ltx_models")
LTX_REPO = "Lightricks/LTX-Video"
MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
MAX_NUM_FRAMES = 257
FPS = 30.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float16

# --- Global variables for loaded models ---
pipeline_instance = None
latent_upsampler_instance = None

models_dir.mkdir(parents=True, exist_ok=True)
| print("Downloading models (if not present)...") | |
| distilled_model_actual_path = hf_hub_download( | |
| repo_id=LTX_REPO, | |
| filename=PIPELINE_CONFIG_YAML["checkpoint_path"], | |
| local_dir=models_dir, | |
| #local_dir_use_symlinks=False, | |
| cache_dir=HF_HOME_CACHE, | |
| ) | |
| PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path | |
| print(f"Distilled model path: {distilled_model_actual_path}") | |
| SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] | |
| spatial_upscaler_actual_path = hf_hub_download( | |
| repo_id=LTX_REPO, | |
| filename=SPATIAL_UPSCALER_FILENAME, | |
| local_dir=models_dir, | |
| #local_dir_use_symlinks=False, | |
| cache_dir=HF_HOME_CACHE, | |
| ) | |
| PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path | |
| print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}") | |
| print("Creating LTX Video pipeline on CPU...") | |
| pipeline_instance = create_ltx_video_pipeline( | |
| ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"], | |
| precision=PIPELINE_CONFIG_YAML["precision"], | |
| text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"], | |
| sampler=PIPELINE_CONFIG_YAML["sampler"], | |
| device="cpu", | |
| enhance_prompt=False, | |
| prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"], | |
| prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"], | |
| ) | |
| print("LTX Video pipeline created on CPU.") | |
| if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"): | |
| print("Creating latent upsampler on CPU...") | |
| latent_upsampler_instance = create_latent_upsampler( | |
| PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"], | |
| device="cpu" | |
| ) | |
| print("Latent upsampler created on CPU.") | |
# Move models to the inference device (falls back to CPU when CUDA is unavailable).
target_inference_device = DEVICE
print(f"Target inference device: {target_inference_device}")
pipeline_instance.to(target_inference_device)
if latent_upsampler_instance:
    latent_upsampler_instance.to(target_inference_device)
# --- Main generation function ---
def generate(
    prompt: str, image_input: Optional[str],
    target_height: int, target_width: int, num_frames: int, seed: int,
    progress=gr.Progress(track_tqdm=True),
):
    seed_everething(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Round spatial dimensions up to the next multiple of 32, then compute
    # the padding needed to go from the requested size to the padded size.
    height_padded = ((target_height - 1) // 32 + 1) * 32
    width_padded = ((target_width - 1) // 32 + 1) * 32
    padding_values = calculate_padding(target_height, target_width, height_padded, width_padded)

    conditioning_items = None
    if image_input:
        progress(0.1, desc="Preparing conditioning image...")
        media_tensor = load_media_file(
            media_path=image_input, height=target_height, width=target_width,
            max_frames=1, padding=padding_values, just_crop=True,
        )
        # Condition the first frame (index 0) at full strength (1.0).
        conditioning_items = [ConditioningItem(media_tensor.to(DEVICE, dtype=DTYPE), 0, 1.0)]

    multi_scale_pipeline = LTXMultiScalePipeline(pipeline_instance, latent_upsampler_instance)
    call_kwargs = {
        "prompt": prompt, "negative_prompt": "worst quality...",
        "height": target_height, "width": target_width, "num_frames": num_frames, "frame_rate": int(FPS),
        "generator": generator, "output_type": "pt",
        "conditioning_items": conditioning_items,
        **PIPELINE_CONFIG_YAML,
    }

    progress(0.3, desc="Generating video...")
    result_tensor = multi_scale_pipeline(**call_kwargs).images

    # Crop the spatial padding back off; a slice end of None keeps the full extent.
    pad_left, pad_right, pad_top, pad_bottom = padding_values
    slice_h_end = -pad_bottom if pad_bottom > 0 else None
    slice_w_end = -pad_right if pad_right > 0 else None
    result_tensor = result_tensor[:, :, :num_frames, pad_top:slice_h_end, pad_left:slice_w_end]

    progress(0.9, desc="Exporting video...")
    output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    # (B, C, T, H, W) -> (T, H, W, C), scaled from [0, 1] to uint8 frames.
    video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = np.clip(video_np * 255, 0, 255).astype("uint8")
    imageio.mimsave(output_video_path, video_np, fps=int(FPS))
    return output_video_path
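# Note: the frame slider below only offers values of the form 8*k + 1
# (17, 25, ..., 161), which appears to match the 8x temporal compression of the
# LTX-Video VAE; MAX_NUM_FRAMES = 257 fits the same pattern (8 * 32 + 1).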
# --- Gradio UI ---
with gr.Blocks(title="LTX-Video (Final)", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1>LTX-Video - Multi-Scale Video Generation (FP8)</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="filepath", label="Input Image (Optional)")
            prompt_in = gr.Textbox(label="Prompt", lines=4, placeholder="E.g.: a cinematic shot...")
            with gr.Accordion("Parameters", open=True):
                height_in = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=480)
                width_in = gr.Slider(label="Width", minimum=256, maximum=1280, step=32, value=832)
                frames_in = gr.Slider(label="Frames", minimum=17, maximum=161, step=8, value=97)
                seed_in = gr.Number(label="Seed", value=42, precision=0)
            run_button = gr.Button("Generate Video", variant="primary")
        with gr.Column(scale=1):
            video_out = gr.Video(label="Generated Video")
    run_button.click(
        fn=generate,
        inputs=[prompt_in, image_in, height_in, width_in, frames_in, seed_in],
        outputs=[video_out],
    )
if __name__ == "__main__":
    demo.queue().launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7861")),
        show_error=True,
    )
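# Minimal client sketch (assumptions: the app is reachable at localhost:7861,
# Gradio exposes the handler under the default api_name "/generate", and
# `gradio_client` is installed):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7861")
#   video_path = client.predict(
#       "a cinematic shot of waves at sunset",  # prompt
#       None,                                   # optional input image (filepath)
#       480, 832, 97, 42,                       # height, width, frames, seed
#       api_name="/generate",
#   )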