File size: 18,327 Bytes
fb56537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
# aduc_framework/managers/ltx_manager.py
#
# Copyright (C) August 4, 2025  Carlos Rodrigues dos Santos
#
# Versão 2.5.0 (Suporte a múltiplos LoRAs e inicialização robusta)
#
# Este manager é responsável por controlar a pipeline LTX-Video. Ele gerencia
# um pool de workers para otimizar o uso de múltiplas GPUs, lida com a inicialização
# e o setup de dependências complexas (clonando o repo LTX de forma segura), e expõe
# uma interface de alto nível para a geração de vídeo, agora com suporte a uma lista
# de LoRAs para composição avançada de estilos.

import torch
import gc
import os
import sys
import yaml
import logging
import huggingface_hub
import time
import threading
import subprocess
import shutil
from pathlib import Path
from typing import Optional, List, Tuple, Union

# --- Imports Relativos Corrigidos ---
from ..types import LatentConditioningItem
from ..tools.optimization import optimize_ltx_worker, can_optimize_fp8
from ..tools.hardware_manager import hardware_manager

logger = logging.getLogger(__name__)

# --- Gerenciamento de Dependências e Placeholders ---
DEPS_DIR = Path("./deps")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"

# Placeholders para módulos importados tardiamente (lazy-loaded)
create_ltx_video_pipeline = None
calculate_padding = None
LTXVideoPipeline = None
ConditioningItem = None
LTXMultiScalePipeline = None
vae_encode = None
latent_to_pixel_coords = None
randn_tensor = None

class LtxPoolManager:
    """
    Gerencia um pool de LtxWorkers e expõe a pipeline de aprimoramento de prompt.
    """
    def __init__(self, device_ids: List[str], ltx_config_file_name: str):
        logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
        self._ltx_modules_loaded = False
        self._setup_dependencies()
        self._lazy_load_ltx_modules()

        self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name

        self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in device_ids]
        self.current_worker_index = 0
        self.lock = threading.Lock()

        self.prompt_enhancement_pipeline = self.workers[0].pipeline if self.workers else None
        if self.prompt_enhancement_pipeline:
            logger.info("LTX POOL MANAGER: Pipeline de aprimoramento de prompt exposta para outros especialistas.")

        self._apply_ltx_pipeline_patches()

        if all(w.device.type == 'cuda' for w in self.workers):
            logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
            for worker in self.workers:
                worker.to_gpu()
            logger.info("LTX POOL MANAGER: Todas as GPUs estão prontas.")
        else:
            logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. Pré-aquecimento de GPU pulado.")


    def _setup_dependencies(self):
        """
        Verifica a integridade do repositório LTX-Video e o clona usando um token de
        autenticação (se disponível), adicionando-o ao sys.path.
        """
        check_file = LTX_VIDEO_REPO_DIR / "README.md"
        
        if not check_file.exists():
            logger.warning(f"Repositório LTX-Video parece estar incompleto ou ausente. Tentando um clone limpo...")
            
            if LTX_VIDEO_REPO_DIR.exists():
                try:
                    shutil.rmtree(LTX_VIDEO_REPO_DIR)
                    logger.info(f"Diretório corrompido '{LTX_VIDEO_REPO_DIR}' removido.")
                except OSError as e:
                    logger.error(f"Não foi possível remover o diretório corrompido: {e}")
                    raise RuntimeError("Falha ao limpar dependência corrompida.")

            # --- INÍCIO DA LÓGICA DE AUTENTICAÇÃO ---
            # Procura pelo token no ambiente (configurado via 'Secrets' no HF Spaces)
            auth_token = os.getenv("HF_TOKEN")
            repo_url = LTX_VIDEO_REPO_URL # URL padrão sem autenticação

            if auth_token:
                logger.info("Token de autenticação (HF_TOKEN) encontrado. Usando URL autenticada para o clone.")
                # Monta a URL no formato https://<token>@github.com/...
                repo_url = LTX_VIDEO_REPO_URL.replace("https://", f"https://{auth_token}@")
            else:
                logger.warning("Nenhum token de autenticação (HF_TOKEN) encontrado. O clone pode falhar se o repositório exigir login para LFS.")
            # --- FIM DA LÓGICA DE AUTENTICAÇÃO ---

            logger.info(f"Clonando repositório LTX-Video...")
            try:
                DEPS_DIR.mkdir(exist_ok=True)
                # Usa a `repo_url` que pode ou não conter o token
                subprocess.run(
                    ["git", "clone", "--depth", "1", repo_url, str(LTX_VIDEO_REPO_DIR), "--quiet"],
                    check=True, capture_output=True, text=True
                )
                logger.info("Repositório LTX-Video clonado com sucesso.")
            except subprocess.CalledProcessError as e:
                logger.error(f"Falha CRÍTICA ao clonar o repositório LTX-Video. Git stderr: {e.stderr}")
                raise RuntimeError(f"Não foi possível clonar a dependência LTX-Video. Causa provável: {e.stderr}")
        else:
            logger.info("Repositório LTX-Video local encontrado e parece completo.")

        resolved_path = str(LTX_VIDEO_REPO_DIR.resolve())
        if resolved_path not in sys.path:
            sys.path.insert(0, resolved_path)
            logger.info(f"Adicionado '{resolved_path}' ao sys.path.")

    
    def _lazy_load_ltx_modules(self):
        """Importa dinamicamente os módulos do LTX-Video após garantir que o repositório existe."""
        if self._ltx_modules_loaded:
            return

        global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
        global vae_encode, latent_to_pixel_coords, randn_tensor
        
        from .ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding
        from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
        from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
        from diffusers.utils.torch_utils import randn_tensor
        
        self._ltx_modules_loaded = True
        logger.info("Módulos do LTX-Video foram carregados dinamicamente.")

    def _apply_ltx_pipeline_patches(self):
        """Aplica patches em tempo de execução na pipeline LTX para compatibilidade com ADUC-SDR."""
        logger.info("LTX POOL MANAGER: Aplicando patches ADUC-SDR na pipeline LTX...")
        for worker in self.workers:
            worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
        logger.info("LTX POOL MANAGER: Todas as instâncias da pipeline foram corrigidas com sucesso.")

    def _get_next_worker(self) -> 'LtxWorker':
        with self.lock:
            worker = self.workers[self.current_worker_index]
            self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
            return worker
    
    def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
        pipeline_params = {
            "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
            "frame_rate": kwargs.get('video_fps', 24),
            "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
            "is_video": True, "vae_per_channel_normalize": True,
            "prompt": kwargs.get('motion_prompt', ""), "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality"),
            "guidance_scale": kwargs.get('guidance_scale', 1.0), "stg_scale": kwargs.get('stg_scale', 0.0),
            "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
            "output_type": "latent"
        }
        if 'latents' in kwargs:
            pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
        if 'strength' in kwargs:
            pipeline_params["strength"] = kwargs['strength']
        
        if 'conditioning_items_data' in kwargs:
            final_conditioning_items = []
            for item in kwargs['conditioning_items_data']:
                item.latent_tensor = item.latent_tensor.to(worker.device)
                final_conditioning_items.append(item)
            pipeline_params["conditioning_items"] = final_conditioning_items

        if worker.is_distilled:
            fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
            if fixed_timesteps:
                pipeline_params["timesteps"] = fixed_timesteps
                pipeline_params["num_inference_steps"] = len(fixed_timesteps)
        
        callback = kwargs.get('callback')
        if callback:
            pipeline_params["callback_on_step_end"] = callback
            pipeline_params["callback_on_step_end_tensor_inputs"] = ["latents"]
        
        return pipeline_params

    def generate_latent_fragment(self, **kwargs) -> Tuple[torch.Tensor, tuple]:
        worker_to_use = self._get_next_worker()
        try:
            height, width = kwargs['height'], kwargs['width']
            padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
            padding_vals = calculate_padding(height, width, padded_h, padded_w)
            kwargs['height'], kwargs['width'] = padded_h, padded_w
            
            pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
            
            logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}")
            
            if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
                result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
            else:
                result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
            return result, padding_vals
        except Exception as e:
            logger.error(f"LTX POOL MANAGER: Erro durante a geração em {worker_to_use.device}: {e}", exc_info=True)
            raise e
        finally:
            if worker_to_use and worker_to_use.device.type == 'cuda':
                with torch.cuda.device(worker_to_use.device):
                    gc.collect()
                    torch.cuda.empty_cache()

    def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, tuple]:
        pass

class LtxWorker:
    """Representa uma única instância da pipeline LTX-Video em um dispositivo específico."""
    def __init__(self, device_id, ltx_config_file):
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
        
        with open(ltx_config_file, "r") as file:
            self.config = yaml.safe_load(file)

        with open("config.yaml", 'r') as f:
            global_config = yaml.safe_load(f)
        lora_config = global_config.get('specialists', {}).get('ltx', {}).get('lora', {})
        
        self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
        models_dir = LTX_VIDEO_REPO_DIR / "models_downloaded"
        
        logger.info(f"LTX Worker ({self.device}): Preparando para carregar modelo...")
        model_filename = self.config["checkpoint_path"]
        model_path = huggingface_hub.hf_hub_download(
            repo_id="Lightricks/LTX-Video", filename=model_filename,
            local_dir=str(models_dir), local_dir_use_symlinks=False
        )
        
        self.pipeline = create_ltx_video_pipeline(
            ckpt_path=model_path, 
            precision=self.config["precision"],
            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
            sampler=self.config["sampler"], 
            device='cpu'
        )

        if lora_config and isinstance(lora_config, list):
            logger.info(f"LTX Worker ({self.device}): Encontrados {len(lora_config)} LoRAs para carregar.")
            
            for lora_item in lora_config:
                lora_model_id = lora_item.get("model_id")
                lora_weight = lora_item.get("weight", 0.7)

                if not lora_model_id:
                    logger.warning(f"LTX Worker ({self.device}): Item de LoRA na lista sem 'model_id'. Pulando.")
                    continue

                logger.info(f"LTX Worker ({self.device}): Carregando e aplicando LoRA '{lora_model_id}' com peso {lora_weight}...")
                try:
                    self.pipeline.load_lora_weights(lora_model_id)
                    self.pipeline.fuse_lora(lora_scale=lora_weight)
                    logger.info(f"LTX Worker ({self.device}): LoRA '{lora_model_id}' fundido ao modelo com sucesso.")
                except Exception as e:
                    logger.error(f"LTX Worker ({self.device}): Falha ao carregar ou fundir o LoRA '{lora_model_id}'. Erro: {e}", exc_info=True)
        
        logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo distilled? {self.is_distilled}")
    
    def to_gpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
        self.pipeline.to(self.device)
        if self.device.type == 'cuda' and can_optimize_fp8():
            logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Otimizando...")
            optimize_ltx_worker(self)
            logger.info(f"LTX Worker ({self.device}): Otimização completa.")
    
    def to_cpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
        self.pipeline.to('cpu')
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()

    def generate_video_fragment_internal(self, **kwargs):
        return self.pipeline(**kwargs).images

def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return init_latents, init_pixel_coords, None, 0
    
    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0
    
    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("Patch ADUC: Item de condicionamento não é um LatentConditioningItem e será ignorado.")
            continue
        
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength
        
        if media_frame_number == 0:
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)
    
    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
    
    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents

# --- Instanciação Singleton ---
with open("config.yaml", 'r') as f:
    config = yaml.safe_load(f)
ltx_gpus_required = config['specialists']['ltx']['gpus_required']
ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
ltx_config_filename = config['specialists']['ltx']['config_file']
ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file_name=ltx_config_filename)
logger.info("Especialista de Vídeo (LTX) pronto.")