x2XcarleX2x's picture
Update aduc_framework/managers/vae_wan_manager.py
e7d2ed1 verified
# aduc_framework/managers/vae_wan_manager.py (Versão Definitiva Completa)
import torch
import logging
import yaml
from PIL import Image
import numpy as np
from typing import List, Optional
import sys
import os
# --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO ---
# Adiciona o diretório do Wan2.2 ao sys.path para encontrar módulos customizados.
WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2'))
if WAN_REPO_PATH not in sys.path:
sys.path.insert(0, WAN_REPO_PATH)
logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para o VaeWanManager.")
# --- FIM DA CORREÇÃO DE IMPORTAÇÃO ---
# Ferramentas da nossa arquitetura
from ..tools.hardware_manager import hardware_manager
# --- IMPORTAÇÕES CORRIGIDAS ---
# Importa a classe de VAE customizada do Wan2.2 e a função auxiliar oficial.
from diffusers.models import AutoencoderKLWan
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
logger = logging.getLogger(__name__)
class VaeWanManager:
"""
Especialista VAE dedicado e persistente para a pipeline Wan2.2.
Gerencia o ciclo de vida do AutoencoderKLWan em uma GPU dedicada,
garantindo a tradução correta entre o espaço de pixels e o espaço latente.
"""
def __init__(self):
self.device = None
self.vae: Optional[AutoencoderKLWan] = None
self.dtype = None
self.config = self._load_config()
if self.config:
gpus_required = self.config.get('gpus_required', 0)
if gpus_required > 0:
self.device = hardware_manager.allocate_gpus('VaeWanManager', gpus_required)[0]
logger.info(f"VaeWanManager: GPU dedicada '{self.device}' reservada.")
else:
self.device = torch.device('cpu')
logger.warning("VaeWanManager: Nenhuma GPU dedicada foi alocada.")
else:
logger.warning("Configuração para 'vae_wan' não encontrada em config.yaml.")
def _load_config(self):
"""Carrega a configuração específica deste manager."""
try:
with open("config.yaml", 'r', encoding='utf-8') as f:
return yaml.safe_load(f).get('specialists', {}).get('vae_wan', {})
except FileNotFoundError:
logger.error("config.yaml não encontrado.")
return None
def _lazy_init(self):
"""Carrega o modelo VAE do Wan2.2 para a VRAM no primeiro uso."""
if self.vae is not None:
return
if not self.device or not self.config:
raise RuntimeError("VaeWanManager não pode ser inicializado.")
logger.info(f"VAE-WAN MANAGER ({self.device}): Carregando VAE do Wan2.2...")
model_id = self.config.get("model_id")
local_model_path = f"/app/models/{model_id}" # Carrega do nosso diretório local
self.dtype = torch.float32 # VAEs são mais estáveis em FP32
try:
# Carrega a classe CORRETA (AutoencoderKLWan) do caminho LOCAL.
self.vae = AutoencoderKLWan.from_pretrained(
local_model_path, subfolder="vae", torch_dtype=self.dtype
).to(self.device)
self.vae.eval()
logger.info(f"VAE-WAN MANAGER ({self.device}): VAE do Wan2.2 pronto e 'quente' na VRAM.")
except Exception as e:
logger.error(f"VAE-WAN MANAGER: Falha CRÍTICA ao carregar o VAE: {e}", exc_info=True)
self.vae = None
raise e
def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
"""Converte uma imagem PIL para o formato de tensor 5D esperado pelo VAE de vídeo."""
from PIL import ImageOps
img = pil_image.convert("RGB")
processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
image_np = np.array(processed_img).astype(np.float32) / 255.0
# Converte para (B, C, H, W)
tensor_4d = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
tensor_4d_normalized = (tensor_4d * 2.0) - 1.0
# Adiciona a dimensão de "frame" para criar um tensor 5D (B, C, F, H, W)
tensor_5d = tensor_4d_normalized.unsqueeze(2)
return tensor_5d
@torch.no_grad()
def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
"""Codifica um lote de imagens PIL para o espaço latente do Wan."""
self._lazy_init()
if not self.vae:
raise RuntimeError("O VAE do WanManager não foi carregado.")
latents_list = []
for img in pil_images:
# A função de pré-processamento agora retorna o tensor 5D correto
pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype)
encoder_output = self.vae.encode(pixel_tensor_gpu)
latents = retrieve_latents(encoder_output)
latents_list.append(latents.cpu())
return latents_list
@torch.no_grad()
def decode(self, latent_tensor: torch.Tensor) -> torch.Tensor:
"""Decodifica um tensor latente do Wan para o espaço de pixels."""
self._lazy_init()
if not self.vae:
raise RuntimeError("O VAE do WanManager não foi carregado.")
latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
# Acessa a saída através do atributo .sample
decode_output = self.vae.decode(latent_tensor_gpu)
pixels = decode_output.sample
return pixels.cpu()
# --- Instância Singleton ---
vae_wan_manager_singleton = VaeWanManager()