# ltx_manager_helpers.py
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
#
# ORIGINAL SOURCE: LTX-Video by Lightricks Ltd. & other open-source projects.
# Licensed under the Apache License, Version 2.0
# https://github.com/Lightricks/LTX-Video
#
# MODIFICATIONS FOR ADUC-SDR_Video:
# This file is part of ADUC-SDR_Video, a derivative work based on LTX-Video.
# It has been modified to manage pools of LTX workers, handle GPU memory,
# and prepare parameters for the ADUC-SDR orchestration framework.
# All modifications are also licensed under the Apache License, Version 2.0.
import torch
import gc
import os
import yaml
import logging
import huggingface_hub
import time
import threading
import json
from optimization import optimize_ltx_worker, can_optimize_fp8
from hardware_manager import hardware_manager
from inference import create_ltx_video_pipeline, calculate_padding
from ltx_video.pipelines.pipeline_ltx_video import LatentConditioningItem
from ltx_video.models.autoencoders.vae_encode import vae_decode
logger = logging.getLogger(__name__)
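
# Architecture note: each LtxWorker pins one LTX-Video pipeline to a single
# device and keeps the weights parked on the CPU. LtxPoolManager (below)
# rotates through the workers round-robin, staging the next worker on its GPU
# while a background thread offloads the previously used one.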
class LtxWorker:
    """Wraps a single LTX-Video pipeline pinned to one device.

    The model is loaded on the CPU and only moved to its GPU while it is
    actively generating, so several workers can share a limited GPU pool.
    """
    def __init__(self, device_id, ltx_config_file):
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        logger.info(f"LTX Worker ({self.device}): Initializing with config '{ltx_config_file}'...")
        with open(ltx_config_file, "r") as file:
            self.config = yaml.safe_load(file)
        self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
        models_dir = "downloaded_models_gradio"
        logger.info(f"LTX Worker ({self.device}): Loading model to the CPU...")
        model_path = os.path.join(models_dir, self.config["checkpoint_path"])
        if not os.path.exists(model_path):
            model_path = huggingface_hub.hf_hub_download(
                repo_id="Lightricks/LTX-Video", filename=self.config["checkpoint_path"],
                local_dir=models_dir, local_dir_use_symlinks=False
            )
        self.pipeline = create_ltx_video_pipeline(
            ckpt_path=model_path, precision=self.config["precision"],
            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
            sampler=self.config["sampler"], device='cpu'
        )
        logger.info(f"LTX Worker ({self.device}): Model ready on the CPU. Distilled model? {self.is_distilled}")
        if self.device.type == 'cuda' and can_optimize_fp8():
            logger.info(f"LTX Worker ({self.device}): GPU with FP8 support detected. Starting optimization...")
            self.pipeline.to(self.device)
            optimize_ltx_worker(self)
            self.pipeline.to(self.cpu_device)
            logger.info(f"LTX Worker ({self.device}): Optimization finished. Model ready.")
        elif self.device.type == 'cuda':
            logger.info(f"LTX Worker ({self.device}): FP8 optimization unsupported or disabled. Using the standard model.")

    def to_gpu(self):
        """Move the pipeline onto this worker's GPU (no-op on CPU-only hosts)."""
        if self.device.type == 'cpu':
            return
        logger.info(f"LTX Worker: Moving pipeline to GPU {self.device}...")
        self.pipeline.to(self.device)

    def to_cpu(self):
        """Offload the pipeline back to the CPU and release cached GPU memory."""
        if self.device.type == 'cpu':
            return
        logger.info(f"LTX Worker: Offloading pipeline from GPU {self.device}...")
        self.pipeline.to('cpu')
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate_video_fragment_internal(self, **kwargs):
        """Run the underlying LTX pipeline and return the generated frames or latents."""
        return self.pipeline(**kwargs).images
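
# Typical lifecycle of a single worker, as driven by LtxPoolManager below
# (the device id, config path, and params are illustrative only):
#
#     worker = LtxWorker('cuda:0', 'ltx_config.yaml')
#     worker.to_gpu()                       # stage the weights on the GPU
#     frames = worker.generate_video_fragment_internal(**pipeline_params)
#     worker.to_cpu()                       # release the GPU for the next worker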

class LtxPoolManager:
    """Round-robin pool of LtxWorkers that overlaps generation on the current
    GPU with background cleanup of the previously used one."""
    def __init__(self, device_ids, ltx_config_file):
        logger.info(f"LTX POOL MANAGER: Creating workers for devices: {device_ids}")
        self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
        self.current_worker_index = 0
        self.lock = threading.Lock()
        self.last_cleanup_thread = None

    def _cleanup_worker_thread(self, worker):
        logger.info(f"LTX CLEANUP THREAD: Starting cleanup of {worker.device} in the background...")
        worker.to_cpu()
    def _prepare_and_log_params(self, worker_to_use, **kwargs):
        target_device = worker_to_use.device
        height, width = kwargs['height'], kwargs['width']
        conditioning_data = kwargs.get('conditioning_items_data', [])
        final_conditioning_items = []
        # --- ADDED LOG: details of the conditioning tensors ---
        conditioning_log_details = []
        for i, item in enumerate(conditioning_data):
            if hasattr(item, 'latent_tensor'):
                item.latent_tensor = item.latent_tensor.to(target_device)
                final_conditioning_items.append(item)
                conditioning_log_details.append(
                    f" - Item {i}: frame={item.media_frame_number}, strength={item.conditioning_strength:.2f}, shape={list(item.latent_tensor.shape)}"
                )
        first_pass_config = worker_to_use.config.get("first_pass", {})
        # Round the resolution up to the next multiple of 32, as the model requires.
        padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
        padding_vals = calculate_padding(height, width, padded_h, padded_w)
        pipeline_params = {
            "height": padded_h, "width": padded_w,
            "num_frames": kwargs['video_total_frames'], "frame_rate": kwargs['video_fps'],
            "generator": torch.Generator(device=target_device).manual_seed(int(kwargs.get('seed', time.time())) + kwargs['current_fragment_index']),
            "conditioning_items": final_conditioning_items,
            "is_video": True, "vae_per_channel_normalize": True,
            "decode_timestep": float(kwargs.get('decode_timestep', worker_to_use.config.get("decode_timestep", 0.05))),
            "decode_noise_scale": float(kwargs.get('decode_noise_scale', worker_to_use.config.get("decode_noise_scale", 0.025))),
            "image_cond_noise_scale": float(kwargs.get('image_cond_noise_scale', 0.0)),
            "stochastic_sampling": bool(kwargs.get('stochastic_sampling', worker_to_use.config.get("stochastic_sampling", False))),
            "prompt": kwargs['motion_prompt'],
            "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality, artifacts"),
            "guidance_scale": float(kwargs.get('guidance_scale', 1.0)),
            "stg_scale": float(kwargs.get('stg_scale', 0.0)),
            "rescaling_scale": float(kwargs.get('rescaling_scale', 1.0)),
        }
        if worker_to_use.is_distilled:
            # Distilled checkpoints ship a fixed timestep schedule in their config;
            # fall back to 8 steps if no schedule is provided.
            timesteps = first_pass_config.get("timesteps")
            pipeline_params["timesteps"] = timesteps
            pipeline_params["num_inference_steps"] = len(timesteps) if timesteps else 8
        else:
            pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 7))
        # --- ADDED LOG: full dump of the pipeline parameters ---
        log_friendly_params = pipeline_params.copy()
        log_friendly_params.pop('generator', None)
        log_friendly_params.pop('conditioning_items', None)
        logger.info("=" * 60)
        logger.info(f"LTX PIPELINE CALL ON DEVICE: {worker_to_use.device}")
        logger.info(f"Model: {'Distilled' if worker_to_use.is_distilled else 'Base'}")
        logger.info("-" * 20 + " PIPELINE PARAMETERS " + "-" * 20)
        logger.info(json.dumps(log_friendly_params, indent=2))
        logger.info("-" * 20 + " CONDITIONING ITEMS " + "-" * 19)
        logger.info("\n".join(conditioning_log_details))
        logger.info("=" * 60)
        # --- END OF ADDED LOG ---
        return pipeline_params, padding_vals
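    # Worked example of the padding step in _prepare_and_log_params: a
    # 1920x1080 request keeps width=1920 (already a multiple of 32) and pads
    # height to ((1080 - 1) // 32 + 1) * 32 = 1088. calculate_padding() then
    # returns the per-edge padding values the caller can use to crop the
    # decoded frames back to the requested resolution (behavior assumed from
    # the LTX-Video inference helpers).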
    def generate_latent_fragment(self, **kwargs) -> tuple[torch.Tensor, tuple]:
        """Generate one latent video fragment, rotating across the GPU pool."""
        worker_to_use = None
        progress = kwargs.get('progress')
        try:
            with self.lock:
                # Wait for any in-flight cleanup before claiming the next worker.
                if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
                    self.last_cleanup_thread.join()
                worker_to_use = self.workers[self.current_worker_index]
                # Offload the previously used worker in the background while
                # this one generates.
                previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
                worker_to_cleanup = self.workers[previous_worker_index]
                cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
                cleanup_thread.start()
                self.last_cleanup_thread = cleanup_thread
                worker_to_use.to_gpu()
                self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
            pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
            pipeline_params['output_type'] = "latent"
            if progress:
                progress(0.1, desc=f"[LTX specialist on {worker_to_use.device}] Generating latents...")
            with torch.no_grad():
                result_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)
            return result_tensor, padding_vals
        except Exception as e:
            logger.error(f"LTX POOL MANAGER: Error during latent generation: {e}", exc_info=True)
            raise
        finally:
            if worker_to_use:
                logger.info(f"LTX POOL MANAGER: Running final cleanup for {worker_to_use.device}...")
                worker_to_use.to_cpu()
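
# Design note: the lock above only guards worker selection and the GPU swap,
# so one worker can generate while the previously used GPU is drained by the
# background cleanup thread; the `finally` block additionally offloads the
# active worker after every fragment, keeping VRAM free even on errors.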
logger.info("Lendo config.yaml para inicializar o LTX Pool Manager...")
with open("config.yaml", 'r') as f:
config = yaml.safe_load(f)
ltx_gpus_required = config['specialists']['ltx']['gpus_required']
ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
ltx_config_path = config['specialists']['ltx']['config_file']
ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file=ltx_config_path)
logger.info("Especialista de Vídeo (LTX) pronto.") |