Spaces:
Paused
Paused
| # aduc_framework/managers/ltx_manager.py | |
| # | |
| # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos | |
| # | |
| # Versão 2.5.0 (Suporte a múltiplos LoRAs e inicialização robusta) | |
| # | |
| # Este manager é responsável por controlar a pipeline LTX-Video. Ele gerencia | |
| # um pool de workers para otimizar o uso de múltiplas GPUs, lida com a inicialização | |
| # e o setup de dependências complexas (clonando o repo LTX de forma segura), e expõe | |
| # uma interface de alto nível para a geração de vídeo, agora com suporte a uma lista | |
| # de LoRAs para composição avançada de estilos. | |
| import torch | |
| import gc | |
| import os | |
| import sys | |
| import yaml | |
| import logging | |
| import huggingface_hub | |
| import time | |
| import threading | |
| import subprocess | |
| import shutil | |
| from pathlib import Path | |
| from typing import Optional, List, Tuple, Union | |
| # --- Imports Relativos Corrigidos --- | |
| from ..types import LatentConditioningItem | |
| from ..tools.optimization import optimize_ltx_worker, can_optimize_fp8 | |
| from ..tools.hardware_manager import hardware_manager | |
| logger = logging.getLogger(__name__) | |
| # --- Gerenciamento de Dependências e Placeholders --- | |
| DEPS_DIR = Path("./deps") | |
| LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video" | |
| LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git" | |
| # Placeholders para módulos importados tardiamente (lazy-loaded) | |
| create_ltx_video_pipeline = None | |
| calculate_padding = None | |
| LTXVideoPipeline = None | |
| ConditioningItem = None | |
| LTXMultiScalePipeline = None | |
| vae_encode = None | |
| latent_to_pixel_coords = None | |
| randn_tensor = None | |
| class LtxPoolManager: | |
| """ | |
| Gerencia um pool de LtxWorkers e expõe a pipeline de aprimoramento de prompt. | |
| """ | |
| def __init__(self, device_ids: List[str], ltx_config_file_name: str): | |
| logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}") | |
| self._ltx_modules_loaded = False | |
| self._setup_dependencies() | |
| self._lazy_load_ltx_modules() | |
| self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name | |
| self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in device_ids] | |
| self.current_worker_index = 0 | |
| self.lock = threading.Lock() | |
| self.prompt_enhancement_pipeline = self.workers[0].pipeline if self.workers else None | |
| if self.prompt_enhancement_pipeline: | |
| logger.info("LTX POOL MANAGER: Pipeline de aprimoramento de prompt exposta para outros especialistas.") | |
| self._apply_ltx_pipeline_patches() | |
| if all(w.device.type == 'cuda' for w in self.workers): | |
| logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...") | |
| for worker in self.workers: | |
| worker.to_gpu() | |
| logger.info("LTX POOL MANAGER: Todas as GPUs estão prontas.") | |
| else: | |
| logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. Pré-aquecimento de GPU pulado.") | |
| def _setup_dependencies(self): | |
| """ | |
| Verifica a integridade do repositório LTX-Video e o clona usando um token de | |
| autenticação (se disponível), adicionando-o ao sys.path. | |
| """ | |
| check_file = LTX_VIDEO_REPO_DIR / "README.md" | |
| if not check_file.exists(): | |
| logger.warning(f"Repositório LTX-Video parece estar incompleto ou ausente. Tentando um clone limpo...") | |
| if LTX_VIDEO_REPO_DIR.exists(): | |
| try: | |
| shutil.rmtree(LTX_VIDEO_REPO_DIR) | |
| logger.info(f"Diretório corrompido '{LTX_VIDEO_REPO_DIR}' removido.") | |
| except OSError as e: | |
| logger.error(f"Não foi possível remover o diretório corrompido: {e}") | |
| raise RuntimeError("Falha ao limpar dependência corrompida.") | |
| # --- INÍCIO DA LÓGICA DE AUTENTICAÇÃO --- | |
| # Procura pelo token no ambiente (configurado via 'Secrets' no HF Spaces) | |
| auth_token = os.getenv("HF_TOKEN") | |
| repo_url = LTX_VIDEO_REPO_URL # URL padrão sem autenticação | |
| if auth_token: | |
| logger.info("Token de autenticação (HF_TOKEN) encontrado. Usando URL autenticada para o clone.") | |
| # Monta a URL no formato https://<token>@github.com/... | |
| repo_url = LTX_VIDEO_REPO_URL.replace("https://", f"https://{auth_token}@") | |
| else: | |
| logger.warning("Nenhum token de autenticação (HF_TOKEN) encontrado. O clone pode falhar se o repositório exigir login para LFS.") | |
| # --- FIM DA LÓGICA DE AUTENTICAÇÃO --- | |
| logger.info(f"Clonando repositório LTX-Video...") | |
| try: | |
| DEPS_DIR.mkdir(exist_ok=True) | |
| # Usa a `repo_url` que pode ou não conter o token | |
| subprocess.run( | |
| ["git", "clone", "--depth", "1", repo_url, str(LTX_VIDEO_REPO_DIR), "--quiet"], | |
| check=True, capture_output=True, text=True | |
| ) | |
| logger.info("Repositório LTX-Video clonado com sucesso.") | |
| except subprocess.CalledProcessError as e: | |
| logger.error(f"Falha CRÍTICA ao clonar o repositório LTX-Video. Git stderr: {e.stderr}") | |
| raise RuntimeError(f"Não foi possível clonar a dependência LTX-Video. Causa provável: {e.stderr}") | |
| else: | |
| logger.info("Repositório LTX-Video local encontrado e parece completo.") | |
| resolved_path = str(LTX_VIDEO_REPO_DIR.resolve()) | |
| if resolved_path not in sys.path: | |
| sys.path.insert(0, resolved_path) | |
| logger.info(f"Adicionado '{resolved_path}' ao sys.path.") | |
| def _lazy_load_ltx_modules(self): | |
| """Importa dinamicamente os módulos do LTX-Video após garantir que o repositório existe.""" | |
| if self._ltx_modules_loaded: | |
| return | |
| global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline | |
| global vae_encode, latent_to_pixel_coords, randn_tensor | |
| from .ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding | |
| from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline | |
| from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords | |
| from diffusers.utils.torch_utils import randn_tensor | |
| self._ltx_modules_loaded = True | |
| logger.info("Módulos do LTX-Video foram carregados dinamicamente.") | |
| def _apply_ltx_pipeline_patches(self): | |
| """Aplica patches em tempo de execução na pipeline LTX para compatibilidade com ADUC-SDR.""" | |
| logger.info("LTX POOL MANAGER: Aplicando patches ADUC-SDR na pipeline LTX...") | |
| for worker in self.workers: | |
| worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline) | |
| logger.info("LTX POOL MANAGER: Todas as instâncias da pipeline foram corrigidas com sucesso.") | |
| def _get_next_worker(self) -> 'LtxWorker': | |
| with self.lock: | |
| worker = self.workers[self.current_worker_index] | |
| self.current_worker_index = (self.current_worker_index + 1) % len(self.workers) | |
| return worker | |
| def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict: | |
| pipeline_params = { | |
| "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'], | |
| "frame_rate": kwargs.get('video_fps', 24), | |
| "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)), | |
| "is_video": True, "vae_per_channel_normalize": True, | |
| "prompt": kwargs.get('motion_prompt', ""), "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality"), | |
| "guidance_scale": kwargs.get('guidance_scale', 1.0), "stg_scale": kwargs.get('stg_scale', 0.0), | |
| "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20), | |
| "output_type": "latent" | |
| } | |
| if 'latents' in kwargs: | |
| pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype) | |
| if 'strength' in kwargs: | |
| pipeline_params["strength"] = kwargs['strength'] | |
| if 'conditioning_items_data' in kwargs: | |
| final_conditioning_items = [] | |
| for item in kwargs['conditioning_items_data']: | |
| item.latent_tensor = item.latent_tensor.to(worker.device) | |
| final_conditioning_items.append(item) | |
| pipeline_params["conditioning_items"] = final_conditioning_items | |
| if worker.is_distilled: | |
| fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps") | |
| if fixed_timesteps: | |
| pipeline_params["timesteps"] = fixed_timesteps | |
| pipeline_params["num_inference_steps"] = len(fixed_timesteps) | |
| callback = kwargs.get('callback') | |
| if callback: | |
| pipeline_params["callback_on_step_end"] = callback | |
| pipeline_params["callback_on_step_end_tensor_inputs"] = ["latents"] | |
| return pipeline_params | |
| def generate_latent_fragment(self, **kwargs) -> Tuple[torch.Tensor, tuple]: | |
| worker_to_use = self._get_next_worker() | |
| try: | |
| height, width = kwargs['height'], kwargs['width'] | |
| padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32 | |
| padding_vals = calculate_padding(height, width, padded_h, padded_w) | |
| kwargs['height'], kwargs['width'] = padded_h, padded_w | |
| pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs) | |
| logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}") | |
| if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline): | |
| result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images | |
| else: | |
| result = worker_to_use.generate_video_fragment_internal(**pipeline_params) | |
| return result, padding_vals | |
| except Exception as e: | |
| logger.error(f"LTX POOL MANAGER: Erro durante a geração em {worker_to_use.device}: {e}", exc_info=True) | |
| raise e | |
| finally: | |
| if worker_to_use and worker_to_use.device.type == 'cuda': | |
| with torch.cuda.device(worker_to_use.device): | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, tuple]: | |
| pass | |
| class LtxWorker: | |
| """Representa uma única instância da pipeline LTX-Video em um dispositivo específico.""" | |
| def __init__(self, device_id, ltx_config_file): | |
| self.cpu_device = torch.device('cpu') | |
| self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu') | |
| logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...") | |
| with open(ltx_config_file, "r") as file: | |
| self.config = yaml.safe_load(file) | |
| with open("config.yaml", 'r') as f: | |
| global_config = yaml.safe_load(f) | |
| lora_config = global_config.get('specialists', {}).get('ltx', {}).get('lora', {}) | |
| self.is_distilled = "distilled" in self.config.get("checkpoint_path", "") | |
| models_dir = LTX_VIDEO_REPO_DIR / "models_downloaded" | |
| logger.info(f"LTX Worker ({self.device}): Preparando para carregar modelo...") | |
| model_filename = self.config["checkpoint_path"] | |
| model_path = huggingface_hub.hf_hub_download( | |
| repo_id="Lightricks/LTX-Video", filename=model_filename, | |
| local_dir=str(models_dir), local_dir_use_symlinks=False | |
| ) | |
| self.pipeline = create_ltx_video_pipeline( | |
| ckpt_path=model_path, | |
| precision=self.config["precision"], | |
| text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"], | |
| sampler=self.config["sampler"], | |
| device='cpu' | |
| ) | |
| if lora_config and isinstance(lora_config, list): | |
| logger.info(f"LTX Worker ({self.device}): Encontrados {len(lora_config)} LoRAs para carregar.") | |
| for lora_item in lora_config: | |
| lora_model_id = lora_item.get("model_id") | |
| lora_weight = lora_item.get("weight", 0.7) | |
| if not lora_model_id: | |
| logger.warning(f"LTX Worker ({self.device}): Item de LoRA na lista sem 'model_id'. Pulando.") | |
| continue | |
| logger.info(f"LTX Worker ({self.device}): Carregando e aplicando LoRA '{lora_model_id}' com peso {lora_weight}...") | |
| try: | |
| self.pipeline.load_lora_weights(lora_model_id) | |
| self.pipeline.fuse_lora(lora_scale=lora_weight) | |
| logger.info(f"LTX Worker ({self.device}): LoRA '{lora_model_id}' fundido ao modelo com sucesso.") | |
| except Exception as e: | |
| logger.error(f"LTX Worker ({self.device}): Falha ao carregar ou fundir o LoRA '{lora_model_id}'. Erro: {e}", exc_info=True) | |
| logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo distilled? {self.is_distilled}") | |
| def to_gpu(self): | |
| if self.device.type == 'cpu': return | |
| logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...") | |
| self.pipeline.to(self.device) | |
| if self.device.type == 'cuda' and can_optimize_fp8(): | |
| logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Otimizando...") | |
| optimize_ltx_worker(self) | |
| logger.info(f"LTX Worker ({self.device}): Otimização completa.") | |
| def to_cpu(self): | |
| if self.device.type == 'cpu': return | |
| logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...") | |
| self.pipeline.to('cpu') | |
| gc.collect() | |
| if torch.cuda.is_available(): torch.cuda.empty_cache() | |
| def generate_video_fragment_internal(self, **kwargs): | |
| return self.pipeline(**kwargs).images | |
| def _aduc_prepare_conditioning_patch( | |
| self: "LTXVideoPipeline", | |
| conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]], | |
| init_latents: torch.Tensor, | |
| num_frames: int, | |
| height: int, | |
| width: int, | |
| vae_per_channel_normalize: bool = False, | |
| generator=None, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: | |
| if not conditioning_items: | |
| init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) | |
| init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning) | |
| return init_latents, init_pixel_coords, None, 0 | |
| init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device) | |
| extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], [] | |
| extra_conditioning_num_latents = 0 | |
| for item in conditioning_items: | |
| if not isinstance(item, LatentConditioningItem): | |
| logger.warning("Patch ADUC: Item de condicionamento não é um LatentConditioningItem e será ignorado.") | |
| continue | |
| media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device) | |
| media_frame_number, strength = item.media_frame_number, item.conditioning_strength | |
| if media_frame_number == 0: | |
| f_l, h_l, w_l = media_item_latents.shape[-3:] | |
| init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength) | |
| init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength | |
| else: | |
| noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype) | |
| media_item_latents = torch.lerp(noise, media_item_latents, strength) | |
| patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents) | |
| pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning) | |
| pixel_coords[:, 0] += media_frame_number | |
| extra_conditioning_num_latents += patched_latents.shape[1] | |
| new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device) | |
| extra_conditioning_latents.append(patched_latents) | |
| extra_conditioning_pixel_coords.append(pixel_coords) | |
| extra_conditioning_mask.append(new_mask) | |
| init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) | |
| init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning) | |
| init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1)) | |
| init_conditioning_mask = init_conditioning_mask.squeeze(-1) | |
| if extra_conditioning_latents: | |
| init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) | |
| init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2) | |
| init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1) | |
| return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents | |
| # --- Instanciação Singleton --- | |
| with open("config.yaml", 'r') as f: | |
| config = yaml.safe_load(f) | |
| ltx_gpus_required = config['specialists']['ltx']['gpus_required'] | |
| ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required) | |
| ltx_config_filename = config['specialists']['ltx']['config_file'] | |
| ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file_name=ltx_config_filename) | |
| logger.info("Especialista de Vídeo (LTX) pronto.") |