# Carlexxx
# feat: ✨ aBINC 2.2
# fb56537
# aduc_framework/managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Versão 11.0.0 (Otimização com FP16 Checkpoint)
#
# Utiliza o checkpoint FP16 otimizado para reduzir significativamente o uso
# de VRAM, mantendo a arquitetura de "monkey patch" para desativar o
# paralelismo problemático e garantir a estabilidade em modo de GPU única.
import torch
import os
import gc
import yaml
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import mediapy
from einops import rearrange
import shutil
from omegaconf import OmegaConf
from typing import Generator, Dict, Any
from ..tools.hardware_manager import hardware_manager
# Module logger, plus base paths resolved against the current working
# directory (the module presumably expects to be imported from the app root).
logger = logging.getLogger(__name__)
APP_ROOT = Path.cwd()
DEPS_DIR = APP_ROOT.joinpath("deps")
SEEDVR_SPACE_DIR = DEPS_DIR.joinpath("SeedVR_Space")
# Hugging Face Space whose repository supplies the SeedVR source tree.
SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
def _load_file_from_url(url, model_dir='./', file_name=None):
"""Helper para baixar arquivos de modelo e checkpoints."""
os.makedirs(model_dir, exist_ok=True)
filename = file_name or os.path.basename(urlparse(url).path)
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
logger.info(f"Baixando arquivo do SeedVR: {filename}...")
download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
return cached_file
class SeedVrManager:
    """Manages one SeedVR pipeline instance on a single dedicated GPU.

    The heavy runner is loaded lazily at the start of each ``process_video``
    call and always unloaded when that call finishes, so VRAM is held only
    while a video is being processed.
    """

    # Sentinel: this manager has not (yet) overridden CUDA_VISIBLE_DEVICES.
    _ENV_UNTOUCHED = object()

    def __init__(self, device_id: str):
        """Run the one-time global setup and record which GPU to use.

        Args:
            device_id: Global device identifier such as ``"cuda:1"``.
        """
        self.global_device_id = device_id
        # CUDA_VISIBLE_DEVICES is narrowed to this GPU before loading, so the
        # process addresses it as index 0.
        self.local_device_name = 'cuda:0'
        self.gpu_index = self._gpu_index_from_device(self.global_device_id)
        self.runner = None
        self._saved_cuda_visible_devices = self._ENV_UNTOUCHED
        self._check_and_run_global_setup()
        logger.info(f"SeedVR Manager (FP16 Optimized) inicializado para operar na GPU {self.global_device_id}.")

    @staticmethod
    def _gpu_index_from_device(device_id: str) -> str:
        """Return the GPU index part of ``device_id`` (``"cuda:1"`` -> ``"1"``).

        A bare digit string is returned as-is; an id without a numeric index
        (e.g. ``"cuda"``) maps to ``"0"``.
        """
        index = device_id.split(':')[-1]
        return index if index.isdigit() else '0'

    @staticmethod
    def _check_and_run_global_setup():
        """Ensure SeedVR files are installed, then patch out its parallelism.

        The first-run setup (clone, copy, apex, checkpoints) runs only while the
        ``seedvr.setup.complete`` flag file is absent. The ``master_only`` patch
        is applied afterwards on every call, because the ``common`` package it
        patches only exists once the setup has copied it into APP_ROOT.

        Returns:
            True once setup is complete.
        """
        if str(APP_ROOT) not in sys.path:
            sys.path.insert(0, str(APP_ROOT))
        setup_flag = DEPS_DIR / "seedvr.setup.complete"
        if not setup_flag.exists():
            SeedVrManager._run_first_time_setup(setup_flag)
        SeedVrManager._patch_master_only()
        return True

    @staticmethod
    def _run_first_time_setup(setup_flag: Path):
        """Fetch SeedVR sources, apex and model checkpoints, then touch ``setup_flag``."""
        logger.info("--- Iniciando Setup Global do SeedVR (primeira execução) ---")
        if not SEEDVR_SPACE_DIR.exists():
            DEPS_DIR.mkdir(exist_ok=True, parents=True)
            logger.info(f"Clonando repositório do SeedVR de {SEEDVR_SPACE_URL}...")
            subprocess.run(["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)], check=True)

        # Packages such as `projects`, `common` and `data` are later imported as
        # top-level modules, and `configs_3b` is read by path from APP_ROOT.
        required_dirs = ["projects", "common", "models", "configs_3b", "data"]
        for dirname in required_dirs:
            source, target = SEEDVR_SPACE_DIR / dirname, APP_ROOT / dirname
            if not target.exists():
                logger.info(f"Copiando diretório '{dirname}' do SeedVR para a raiz do projeto...")
                shutil.copytree(source, target)

        try:
            import apex  # noqa: F401  (only checks availability)
        except ImportError:
            logger.info("Dependência 'apex' não encontrada. Instalando a partir do wheel fornecido...")
            apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
            apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
            subprocess.run([sys.executable, '-m', 'pip', 'install', apex_wheel_path], check=True)

        ckpt_dir = APP_ROOT / 'ckpts'
        ckpt_dir.mkdir(exist_ok=True)
        # FP16 DiT/VAE checkpoints reduce VRAM use; the text embeddings come from
        # the official SeedVR2-3B repository.
        model_urls = {
            'vae': 'https://huggingface.co/batuhanince/seedvr_3b_fp16/resolve/main/ema_vae_fp16.safetensors',
            'dit_3b': 'https://huggingface.co/batuhanince/seedvr_3b_fp16/resolve/main/seedvr2_ema_3b_fp16.safetensors',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
        }
        for url in model_urls.values():
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))
        setup_flag.touch()
        logger.info("--- Setup Global do SeedVR Concluído ---")

    @staticmethod
    def _patch_master_only():
        """Replace ``common.decorators.master_only`` with a passthrough decorator.

        This is how the manager disables SeedVR's parallel code path so it
        runs on a single GPU. It must run before SeedVR modules that apply the
        decorator are imported, so they pick up the patched version. Failures
        are logged, not raised.
        """
        try:
            from common import decorators
            import functools

            def _passthrough_decorator(func):
                @functools.wraps(func)
                def wrapped(*args, **kwargs):
                    return func(*args, **kwargs)
                return wrapped

            decorators.master_only = _passthrough_decorator
            logger.info("Monkey patch aplicado com sucesso em 'common.decorators.master_only' para desativar o paralelismo.")
        except Exception as e:
            logger.error(f"Falha ao aplicar o monkey patch para o SeedVR: {e}", exc_info=True)

    def _pin_visible_gpu(self):
        """Restrict CUDA_VISIBLE_DEVICES to this manager's GPU.

        The value present before the first override is saved so that
        ``_unload_runner`` can restore it.
        NOTE(review): CUDA only honours this variable if it is set before CUDA
        is initialised in the process; confirm nothing initialises CUDA first.
        """
        if getattr(self, '_saved_cuda_visible_devices', self._ENV_UNTOUCHED) is self._ENV_UNTOUCHED:
            self._saved_cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpu_index

    def _restore_visible_gpus(self):
        """Undo ``_pin_visible_gpu``: restore (or remove) the saved CUDA_VISIBLE_DEVICES."""
        saved = getattr(self, '_saved_cuda_visible_devices', self._ENV_UNTOUCHED)
        if saved is self._ENV_UNTOUCHED:
            return
        if saved is None:
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = saved
        self._saved_cuda_visible_devices = self._ENV_UNTOUCHED

    def _initialize_runner(self):
        """Build the SeedVR 3B runner with the FP16 DiT checkpoint (no-op if loaded)."""
        if self.runner is not None:
            return
        self._pin_visible_gpu()
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config
        logger.info(f"Manager na GPU {self.global_device_id}: Inicializando runner SeedVR 3B com checkpoint FP16...")
        config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
        checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_3b_fp16.safetensors'
        config = load_config(str(config_path))
        self.runner = VideoDiffusionInfer(config)
        # process_video writes the sampling step count into the config.
        OmegaConf.set_readonly(self.runner.config, False)
        self.runner.configure_dit_model(device=self.local_device_name, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        logger.info(f"Manager na GPU {self.global_device_id}: Runner 3B (FP16) pronto na VRAM.")

    def _unload_runner(self):
        """Drop the runner, release cached VRAM and restore CUDA_VISIBLE_DEVICES."""
        if self.runner is not None:
            del self.runner
            self.runner = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Manager na GPU {self.global_device_id}: Runner descarregado da VRAM.")
        self._restore_visible_gpus()

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str, steps: int = 100, seed: int = 666) -> Generator[Dict[str, Any], None, None]:
        """Upscale a video with SeedVR, yielding progress updates.

        Args:
            input_video_path: Video to upscale.
            output_video_path: Destination of the upscaled video.
            prompt: Accepted for interface compatibility but unused; the
                precomputed ``pos_emb.pt``/``neg_emb.pt`` embeddings are used.
            steps: Number of diffusion sampling steps.
            seed: Seed passed to SeedVR's ``set_seed``.

        Yields:
            Dicts with ``progress`` (0.0-1.0) and ``desc``; the final one has
            ``final_path`` instead of ``desc``.

        The output keeps the input's frame rate (24 fps if unknown). The runner
        is unloaded when the generator finishes, fails or is closed.
        """
        # Yields are kept outside torch.no_grad()/autocast blocks: those modes
        # are thread-local state and must not leak to the caller while the
        # generator is suspended.
        try:
            self._initialize_runner()
            yield {"progress": 0.1, "desc": "Runner inicializado. Lendo vídeo..."}
            device = torch.device(self.local_device_name)
            from common.seed import set_seed
            from data.image.transforms.divisible_crop import DivisibleCrop
            from data.image.transforms.na_resize import NaResize
            from data.video.transforms.rearrange import Rearrange
            from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
            from torchvision.transforms import Compose, Lambda, Normalize
            from torchvision.io.video import read_video

            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            video_frames, _, video_info = read_video(input_video_path, output_format="TCHW")
            # Preserve the source frame rate; fall back to 24 fps if unavailable.
            output_fps = video_info.get("video_fps") or 24
            video_tensor = video_frames / 255.0
            del video_frames
            res_h, res_w = video_tensor.shape[-2:]
            video_transform = Compose([
                # Target resolution is derived from the input's own pixel area
                # (presumably no net resize) — TODO confirm NaResize semantics.
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),  # [0, 1] -> [-1, 1]
                Rearrange("t c h w -> c t h w"),
            ])
            cond_videos = [video_transform(video_tensor.to(device))]
            del video_tensor
            # Pixel-space reference for the final color fix. It must be taken
            # BEFORE encoding (the latents have different channels and size).
            # It is kept on the CPU to save VRAM during diffusion.
            color_reference = cond_videos[0].cpu()
            yield {"progress": 0.2, "desc": "Encodificando para o espaço latente..."}

            # Only one of DiT / VAE is kept on the GPU at a time to limit VRAM use.
            self.runner.dit.to("cpu")
            self.runner.vae.to(device)
            with torch.no_grad():
                cond_latents = self.runner.vae_encode(cond_videos)
            del cond_videos
            self.runner.vae.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.dit.to(device)

            pos_emb = torch.load(APP_ROOT / 'ckpts' / 'pos_emb.pt', map_location="cpu").to(device)
            neg_emb = torch.load(APP_ROOT / 'ckpts' / 'neg_emb.pt', map_location="cpu").to(device)
            text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}
            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]
            yield {"progress": 0.5, "desc": "Aplicando difusão para super-resolução..."}

            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
            yield {"progress": 0.8, "desc": "Decodificando para o espaço de pixel..."}

            self.runner.dit.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.vae.to(device)
            with torch.no_grad():
                samples = self.runner.vae_decode(video_tensors)
            final_sample = samples[0]
            # Trim the reference to the number of decoded frames (dim 1 is time).
            if final_sample.shape[1] < color_reference.shape[1]:
                color_reference = color_reference[:, :final_sample.shape[1]]
            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(color_reference.to(final_sample.device), "c t h w -> t c h w"),
            )
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            # Map [-1, 1] back to [0, 255].
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            yield {"progress": 0.9, "desc": "Escrevendo arquivo de vídeo final..."}
            mediapy.write_video(output_video_path, final_sample_np, fps=output_fps)
            yield {"progress": 1.0, "final_path": output_video_path}
        finally:
            self._unload_runner()
# --- Instanciação Singleton ---
class SeedVrPlaceholder:
    """No-op stand-in used when SeedVR is disabled or could not be initialised.

    Exposes the same ``process_video`` generator entry point, but performs no
    upscaling: it logs a warning and reports the input video as the result.
    """

    def process_video(self, input_video_path, *args, **kwargs):
        """Yield one completion update whose ``final_path`` is the untouched input."""
        logger.warning("SeedVR está desabilitado (gpus_required: 0). Pulando etapa de masterização HD.")
        completed = {"progress": 1.0, "final_path": input_video_path}
        yield completed
# Module-level singleton: a real SeedVrManager when config.yaml asks for GPUs
# for SeedVR and the hardware manager grants one; otherwise a SeedVrPlaceholder.
try:
    with open("config.yaml", 'r', encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no settings".
        config = yaml.safe_load(f) or {}
    # Missing/empty 'specialists' or 'seedvr' sections, or a null
    # 'gpus_required', all mean SeedVR is disabled.
    seedvr_settings = (config.get('specialists') or {}).get('seedvr') or {}
    seedvr_gpus_required = int(seedvr_settings.get('gpus_required') or 0)
    if seedvr_gpus_required > 0:
        seedvr_device_ids = hardware_manager.allocate_gpus('SeedVR', seedvr_gpus_required)
        if seedvr_device_ids and 'cpu' not in seedvr_device_ids:
            device_to_use = seedvr_device_ids[0]
            seedvr_manager_singleton = SeedVrManager(device_id=device_to_use)
            logger.info(f"Especialista de Masterização HD (SeedVR FP16) pronto para usar a GPU {device_to_use}.")
        else:
            seedvr_manager_singleton = SeedVrPlaceholder()
    else:
        seedvr_manager_singleton = SeedVrPlaceholder()
except Exception as e:
    # Any failure (missing config file, bad values, setup errors) degrades to
    # the placeholder so importing this module never fails.
    logger.critical(f"Falha CRÍTICA ao inicializar o SeedVrManager: {e}", exc_info=True)
    seedvr_manager_singleton = SeedVrPlaceholder()