File size: 3,738 Bytes
fb56537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# upscaler_specialist.py
# Copyright (C) 2025 Carlos Rodrigues
# Especialista ADUC para upscaling espacial de tensores latentes.

import torch
import logging
from diffusers import LTXLatentUpsamplePipeline
from ..managers.ltx_manager import ltx_manager_singleton

logger = logging.getLogger(__name__)

class UpscalerSpecialist:
    """
    Especialista responsável por aumentar a resolução espacial de tensores latentes
    usando o LTX Video Spatial Upscaler.
    """
    def __init__(self):
        # Força uso de CUDA se disponível
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_vae = None
        self.pipe_upsample = None


    def _lazy_init(self):
      try:
        # Tenta usar o VAE do ltx_manager
        if ltx_manager_singleton.workers:
            candidate_vae = ltx_manager_singleton.workers[0].pipeline.vae
            if candidate_vae.__class__.__name__ == "AutoencoderKLLTXVideo":
                self.base_vae = candidate_vae
                logger.info("[Upscaler] Usando VAE do ltx_manager (AutoencoderKLLTXVideo).")
            else:
                logger.warning(f"[Upscaler] VAE incompatível: {type(candidate_vae)}. "
                               "Carregando AutoencoderKLLTXVideo manualmente...")
                from diffusers.models.autoencoders import AutoencoderKLLTXVideo
                self.base_vae = AutoencoderKLLTXVideo.from_pretrained(
                    "linoyts/LTX-Video-spatial-upscaler-0.9.8",
                    subfolder="vae",
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                ).to(self.device)
        else:
            logger.warning("[Upscaler] Nenhum worker disponível, carregando VAE manualmente...")
            from diffusers.models.autoencoders import AutoencoderKLLTXVideo
            self.base_vae = AutoencoderKLLTXVideo.from_pretrained(
                "linoyts/LTX-Video-spatial-upscaler-0.9.8",
                subfolder="vae",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)

        # Carregar pipeline
        self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
            "linoyts/LTX-Video-spatial-upscaler-0.9.8",
            vae=self.base_vae,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)

        logger.info("[Upscaler] Pipeline carregado com sucesso.")

      except Exception as e:
        logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}")
        self.pipe_upsample = None
        

    
    @torch.no_grad()
    def upscale(self, latents: torch.Tensor) -> torch.Tensor:
        """Aplica o upscaling 2x nos tensores latentes fornecidos."""
        self._lazy_init()
        if self.pipe_upsample is None:
            logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
            return latents

        try:
            logger.info(f"[Upscaler] Recebido shape {latents.shape}. Executando upscale em {self.device}...")
            
            # [CORREÇÃO FINAL] Conforme a documentação oficial, o resultado está em .frames
            result = self.pipe_upsample(latents=latents, output_type="latent")
            output_tensor = result.frames
            
            logger.info(f"[Upscaler] Upscale concluído. Novo shape: {output_tensor.shape}")
            return output_tensor
        
        except Exception as e:
            logger.error(f"[Upscaler] Erro durante upscale: {e}", exc_info=True)
            return latents


# ---------------------------
# Singleton global
# ---------------------------
upscaler_specialist_singleton = UpscalerSpecialist()