# Aduc-sdr-2_5s / app_ltx.py
import gradio as gr
import torch
import numpy as np
import os
import yaml
from typing import Optional
from pathlib import Path
import imageio
import tempfile
from huggingface_hub import hf_hub_download

from inference import (
    create_ltx_video_pipeline,
    create_latent_upsampler,
    seed_everething,  # sic: the function is spelled this way in inference.py
    calculate_padding,
    load_media_file,
)
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
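# Overview: load the distilled FP8 config, fetch weights from the Hub, build
# the pipelines on CPU, move them to the inference device, and expose a single
# generate() function through a Gradio UI.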
APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
config_file_path = APP_HOME / "configs/ltxv-13b-0.9.8-distilled-fp8.yaml"
with open(config_file_path, "r") as file:
PIPELINE_CONFIG_YAML = yaml.safe_load(file)
HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
models_dir = Path("/data/ltx_models")
LTX_REPO = "Lightricks/LTX-Video"
MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
MAX_NUM_FRAMES = 257
FPS = 30.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
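# bf16 is preferred where the GPU supports it (wider exponent range than fp16,
# so it is less prone to overflow during denoising); fp16 is the fallback.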
# --- Global variables for loaded models ---
pipeline_instance = None
latent_upsampler_instance = None
models_dir.mkdir(parents=True, exist_ok=True)
print("Downloading models (if not present)...")
distilled_model_actual_path = hf_hub_download(
    repo_id=LTX_REPO,
    filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
    local_dir=models_dir,
    cache_dir=HF_HOME_CACHE,
)
PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
print(f"Distilled model path: {distilled_model_actual_path}")
SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
spatial_upscaler_actual_path = hf_hub_download(
    repo_id=LTX_REPO,
    filename=SPATIAL_UPSCALER_FILENAME,
    local_dir=models_dir,
    cache_dir=HF_HOME_CACHE,
)
PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
print("Creating LTX Video pipeline on CPU...")
pipeline_instance = create_ltx_video_pipeline(
ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
precision=PIPELINE_CONFIG_YAML["precision"],
text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
sampler=PIPELINE_CONFIG_YAML["sampler"],
device="cpu",
enhance_prompt=False,
prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
)
print("LTX Video pipeline created on CPU.")
if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
print("Creating latent upsampler on CPU...")
latent_upsampler_instance = create_latent_upsampler(
PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
device="cpu"
)
print("Latent upsampler created on CPU.")
target_inference_device = DEVICE
print(f"Target inference device: {target_inference_device}")
pipeline_instance.to(target_inference_device)
if latent_upsampler_instance:
latent_upsampler_instance.to(target_inference_device)
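# Building on CPU and moving to the device in one .to() call can keep peak
# GPU memory during startup lower than constructing directly on the GPU.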
# --- Main generation function ---
def generate(
prompt: str, image_input: Optional[str],
target_height: int, target_width: int, num_frames: int, seed: int,
progress=gr.Progress(track_tqdm=True)
):
seed_everething(seed)
generator = torch.Generator(device=DEVICE).manual_seed(seed)
height_padded = ((target_height - 1) // 32 + 1) * 32
width_padded = ((target_width - 1) // 32 + 1) * 32
padding_values = calculate_padding(target_height, target_width, height_padded, width_padded)
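    # Dimensions are rounded up to the next multiple of 32, which the
    # pipeline requires. Example: 480x832 needs no padding; 500x820 would be
    # padded to 512x832, with the padding cropped away after decoding.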
conditioning_items = None
if image_input:
        progress(0.1, desc="Preparing conditioning image...")
media_tensor = load_media_file(
media_path=image_input, height=target_height, width=target_width,
max_frames=1, padding=padding_values, just_crop=True
)
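        # ConditioningItem(media, frame_index, strength): condition frame 0
        # at full strength, which turns the run into image-to-video.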
conditioning_items = [ConditioningItem(media_tensor.to(DEVICE, dtype=DTYPE), 0, 1.0)]
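    # LTXMultiScalePipeline: a first pass at reduced resolution, latent
    # upsampling via the spatial upscaler, then a refinement pass at the
    # target size (driven by first_pass/second_pass in the YAML config).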
multi_scale_pipeline = LTXMultiScalePipeline(pipeline_instance, latent_upsampler_instance)
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": "worst quality...",
        "height": target_height,
        "width": target_width,
        "num_frames": num_frames,
        "frame_rate": int(FPS),
        "generator": generator,
        "output_type": "pt",
        "conditioning_items": conditioning_items,
        # Spreading the YAML forwards the multi-scale keys (downscale_factor,
        # first_pass, second_pass, ...); this assumes the pipeline tolerates
        # the unrelated config keys that come along with them.
        **PIPELINE_CONFIG_YAML,
    }
    progress(0.3, desc="Generating video...")
result_tensor = multi_scale_pipeline(**call_kwargs).images
pad_left, pad_right, pad_top, pad_bottom = padding_values
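    # A slice end of -0 would select an empty range, so fall back to None
    # when an edge carries no padding.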
slice_h_end = -pad_bottom if pad_bottom > 0 else None
slice_w_end = -pad_right if pad_right > 0 else None
result_tensor = result_tensor[:, :, :num_frames, pad_top:slice_h_end, pad_left:slice_w_end]
    progress(0.9, desc="Exporting video...")
    output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    # [B, C, F, H, W] -> frames of shape [H, W, C], uint8, for the writer.
    video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = np.clip(video_np * 255, 0, 255).astype("uint8")
    # Write with imageio (already imported); the original called an
    # export_to_video() that was never imported, and at 24 fps rather than
    # the FPS used for generation.
    with imageio.get_writer(output_video_path, fps=FPS) as writer:
        for frame in video_np:
            writer.append_data(frame)
return output_video_path
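# Hypothetical programmatic use, bypassing the UI (assumes the models above
# loaded successfully on this host):
#   path = generate("a cinematic drone shot over a coastline", None,
#                   target_height=480, target_width=832, num_frames=97, seed=42)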
# --- Gradio UI ---
with gr.Blocks(title="LTX-Video (Final)", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1>LTX-Video - Multi-Scale Video Generation (FP8)</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="filepath", label="Input Image (Optional)")
            prompt_in = gr.Textbox(label="Prompt", lines=4, placeholder="e.g. a cinematic shot...")
            with gr.Accordion("Parameters", open=True):
                height_in = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=480)
                width_in = gr.Slider(label="Width", minimum=256, maximum=1280, step=32, value=832)
                frames_in = gr.Slider(label="Frames", minimum=17, maximum=161, step=8, value=97)
                seed_in = gr.Number(label="Seed", value=42, precision=0)
            run_button = gr.Button("Generate Video", variant="primary")
        with gr.Column(scale=1):
            video_out = gr.Video(label="Generated Video")
run_button.click(
fn=generate,
inputs=[prompt_in, image_in, height_in, width_in, frames_in, seed_in],
outputs=[video_out],
)
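# queue() below queues concurrent requests so they do not contend for the
# single in-memory pipeline on the GPU.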
if __name__ == "__main__":
demo.queue().launch(
server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
server_port=int(os.getenv("GRADIO_SERVER_PORT", "7861")),
show_error=True,
)