Spaces:
Paused
Paused
| # aduc_framework/managers/vae_wan_manager.py (Versão Definitiva Completa) | |
| import torch | |
| import logging | |
| import yaml | |
| from PIL import Image | |
| import numpy as np | |
| from typing import List, Optional | |
| import sys | |
| import os | |
| # --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO --- | |
| # Adiciona o diretório do Wan2.2 ao sys.path para encontrar módulos customizados. | |
| WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2')) | |
| if WAN_REPO_PATH not in sys.path: | |
| sys.path.insert(0, WAN_REPO_PATH) | |
| logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para o VaeWanManager.") | |
| # --- FIM DA CORREÇÃO DE IMPORTAÇÃO --- | |
| # Ferramentas da nossa arquitetura | |
| from ..tools.hardware_manager import hardware_manager | |
| # --- IMPORTAÇÕES CORRIGIDAS --- | |
| # Importa a classe de VAE customizada do Wan2.2 e a função auxiliar oficial. | |
| from diffusers.models import AutoencoderKLWan | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents | |
| logger = logging.getLogger(__name__) | |
| class VaeWanManager: | |
| """ | |
| Especialista VAE dedicado e persistente para a pipeline Wan2.2. | |
| Gerencia o ciclo de vida do AutoencoderKLWan em uma GPU dedicada, | |
| garantindo a tradução correta entre o espaço de pixels e o espaço latente. | |
| """ | |
| def __init__(self): | |
| self.device = None | |
| self.vae: Optional[AutoencoderKLWan] = None | |
| self.dtype = None | |
| self.config = self._load_config() | |
| if self.config: | |
| gpus_required = self.config.get('gpus_required', 0) | |
| if gpus_required > 0: | |
| self.device = hardware_manager.allocate_gpus('VaeWanManager', gpus_required)[0] | |
| logger.info(f"VaeWanManager: GPU dedicada '{self.device}' reservada.") | |
| else: | |
| self.device = torch.device('cpu') | |
| logger.warning("VaeWanManager: Nenhuma GPU dedicada foi alocada.") | |
| else: | |
| logger.warning("Configuração para 'vae_wan' não encontrada em config.yaml.") | |
| def _load_config(self): | |
| """Carrega a configuração específica deste manager.""" | |
| try: | |
| with open("config.yaml", 'r', encoding='utf-8') as f: | |
| return yaml.safe_load(f).get('specialists', {}).get('vae_wan', {}) | |
| except FileNotFoundError: | |
| logger.error("config.yaml não encontrado.") | |
| return None | |
| def _lazy_init(self): | |
| """Carrega o modelo VAE do Wan2.2 para a VRAM no primeiro uso.""" | |
| if self.vae is not None: | |
| return | |
| if not self.device or not self.config: | |
| raise RuntimeError("VaeWanManager não pode ser inicializado.") | |
| logger.info(f"VAE-WAN MANAGER ({self.device}): Carregando VAE do Wan2.2...") | |
| model_id = self.config.get("model_id") | |
| local_model_path = f"/app/models/{model_id}" # Carrega do nosso diretório local | |
| self.dtype = torch.float32 # VAEs são mais estáveis em FP32 | |
| try: | |
| # Carrega a classe CORRETA (AutoencoderKLWan) do caminho LOCAL. | |
| self.vae = AutoencoderKLWan.from_pretrained( | |
| local_model_path, subfolder="vae", torch_dtype=self.dtype | |
| ).to(self.device) | |
| self.vae.eval() | |
| logger.info(f"VAE-WAN MANAGER ({self.device}): VAE do Wan2.2 pronto e 'quente' na VRAM.") | |
| except Exception as e: | |
| logger.error(f"VAE-WAN MANAGER: Falha CRÍTICA ao carregar o VAE: {e}", exc_info=True) | |
| self.vae = None | |
| raise e | |
| def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor: | |
| """Converte uma imagem PIL para o formato de tensor 5D esperado pelo VAE de vídeo.""" | |
| from PIL import ImageOps | |
| img = pil_image.convert("RGB") | |
| processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS) | |
| image_np = np.array(processed_img).astype(np.float32) / 255.0 | |
| # Converte para (B, C, H, W) | |
| tensor_4d = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0) | |
| tensor_4d_normalized = (tensor_4d * 2.0) - 1.0 | |
| # Adiciona a dimensão de "frame" para criar um tensor 5D (B, C, F, H, W) | |
| tensor_5d = tensor_4d_normalized.unsqueeze(2) | |
| return tensor_5d | |
| def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]: | |
| """Codifica um lote de imagens PIL para o espaço latente do Wan.""" | |
| self._lazy_init() | |
| if not self.vae: | |
| raise RuntimeError("O VAE do WanManager não foi carregado.") | |
| latents_list = [] | |
| for img in pil_images: | |
| # A função de pré-processamento agora retorna o tensor 5D correto | |
| pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype) | |
| encoder_output = self.vae.encode(pixel_tensor_gpu) | |
| latents = retrieve_latents(encoder_output) | |
| latents_list.append(latents.cpu()) | |
| return latents_list | |
| def decode(self, latent_tensor: torch.Tensor) -> torch.Tensor: | |
| """Decodifica um tensor latente do Wan para o espaço de pixels.""" | |
| self._lazy_init() | |
| if not self.vae: | |
| raise RuntimeError("O VAE do WanManager não foi carregado.") | |
| latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype) | |
| # Acessa a saída através do atributo .sample | |
| decode_output = self.vae.decode(latent_tensor_gpu) | |
| pixels = decode_output.sample | |
| return pixels.cpu() | |
| # --- Instância Singleton --- | |
| vae_wan_manager_singleton = VaeWanManager() |