#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations import os import sys import shutil import subprocess import textwrap from pathlib import Path from typing import Optional, List, Dict from huggingface_hub import snapshot_download APP_HOME = Path(os.environ.get("APP_HOME", "/app")) SEED_REPO_DIR = APP_HOME / "SeedVR" # ajuste se o repo clonado estiver em outro path CONFIGS_DIR = APP_HOME / "configs_3b" CKPT_DIR = APP_HOME / "ckpt" / "SeedVR2-3B" MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/app/models")) REPO_ID = os.environ.get("REPO_ID_SEED", "ByteDance-Seed/SeedVR2-3B") # Arquivos essenciais (conforme app-13 e seedvr-1.sh) REQUIRED_FILES = [ "seedvr2_ema_3b.pth", # modelo principal "ema_vae.pth", # VAE "pos_emb.pt", # embeddings positivos "neg_emb.pt", # embeddings negativos ] def _env_bool(name: str, default: bool = True) -> bool: v = os.environ.get(name) return default if v is None else v.strip().lower() in ("1", "true", "yes", "on") class SeedVRRefineService: def __init__(self) -> None: self.app_home = APP_HOME self.repo_dir = SEED_REPO_DIR self.configs = CONFIGS_DIR self.ckpt_dir = CKPT_DIR # ---------------- Apex shim (sem Apex real) ---------------- def ensure_apex(self, enable_shim: bool = True) -> None: if not enable_shim: return shims_dir = Path("/app/shims/apex") shims_dir.mkdir(parents=True, exist_ok=True) norm_py = shims_dir / "normalization.py" code = textwrap.dedent(""" import torch import torch.nn as nn class FusedRMSNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True): super().__init__() try: self.norm = nn.RMSNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) except AttributeError: self.norm = _RMSNormFallback(normalized_shape, eps, elementwise_affine) def forward(self, x): return self.norm(x) class FusedLayerNorm(nn.LayerNorm): pass class _RMSNormFallback(nn.Module): def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True): super().__init__() if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,) self.normalized_shape = tuple(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine if elementwise_affine: self.weight = nn.Parameter(torch.ones(self.normalized_shape)) else: self.register_parameter("weight", None) def forward(self, x): dim = tuple(range(-len(self.normalized_shape), 0)) variance = x.pow(2).mean(dim=dim, keepdim=True) x = x * torch.rsqrt(variance + self.eps) if self.weight is not None: x = x * self.weight return x """) norm_py.write_text(code) shims_root = str(Path("/app/shims")) if shims_root not in sys.path: sys.path.insert(0, shims_root) os.environ["PYTHONPATH"] = shims_root + (":" + os.environ["PYTHONPATH"] if "PYTHONPATH" in os.environ else "") def _preflight_imports(self) -> None: try: import importlib importlib.import_module("apex.normalization") print("apex shim OK (apex.normalization resolvido)") except Exception as e: raise RuntimeError(f"apex shim não resolvido: {e}") # ---------------- Modelos ---------------- def ensure_model(self, max_workers: int = 48, token: Optional[str] = None) -> str: self.ckpt_dir.mkdir(parents=True, exist_ok=True) # já baixados? have = all((self.ckpt_dir / f).exists() for f in REQUIRED_FILES) if not have: allow = [ "seedvr2_ema_3b.pth", "ema_vae.pth", "pos_emb.pt", "neg_emb.pt", "*.md", "*.txt" ] if token: try: from huggingface_hub import login login(token=token) except Exception: pass snapshot_download( repo_id=REPO_ID, local_dir=str(self.ckpt_dir), local_dir_use_symlinks=False, resume_download=True, max_workers=max_workers, allow_patterns=allow ) return f"SeedVR ckpts prontos em {self.ckpt_dir}" # ---------------- Ambiente GPU ---------------- def _gpu_env(self) -> Dict[str, str]: env = os.environ.copy() env.setdefault("CUDA_VISIBLE_DEVICES", os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")) env.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "32") env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") env.setdefault("CUDA_MODULE_LOADING", "LAZY") env.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512,garbage_collection_threshold:0.8") # NCCL env.setdefault("NCCL_DEBUG", "INFO") env.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1") env.setdefault("NCCL_P2P_DISABLE", "0") env.setdefault("NCCL_IB_DISABLE", "1") # ajuste p/ sua topologia env.setdefault("NCCL_MIN_NCHANNELS", "8") env.setdefault("NCCL_NTHREADS", "256") # SDPA/FA env.setdefault("ENABLE_FLASH_SDP", "1") env.setdefault("ENABLE_MEMORY_EFFICIENT_SDP", "1") env.setdefault("ENABLE_MATH_SDP", "0") env.setdefault("FLASH_ATTENTION_DISABLE", "0") env.setdefault("XFORMERS_FORCE_DISABLE", "1") env.setdefault("TORCH_DTYPE", os.environ.get("TORCH_DTYPE", "bfloat16")) # PYTHONPATH com shim shims_root = "/app/shims" env["PYTHONPATH"] = shims_root + (":" + env["PYTHONPATH"] if "PYTHONPATH" in env else "") return env # ---------------- Execução de refine ---------------- def _find_refine_script(self) -> Path: # Baseado no app-13, o runner é VideoDiffusionInfer, porém aqui criamos um entrypoint "refine_cli.py" # Caso o repositório traga um script específico de sr/refine, ajuste a lista abaixo: candidates = [ self.repo_dir / "projects" / "video_diffusion_sr" / "refine_cli.py", self.repo_dir / "inference_refine.py", self.repo_dir / "inference.py", ] for p in candidates: if p.exists(): return p raise FileNotFoundError("Script de refine do SeedVR não encontrado (ajuste _find_refine_script).") def refine( self, input_path: Path, output_dir: Optional[Path] = None, upscale: float = 1.0, strength: float = 0.35, denoise: float = 0.1, t_consistency: float = 0.7, fps_out: Optional[int] = None, tile: Optional[int] = None, dtype: Optional[str] = None, extra: Optional[List[str]] = None ) -> str: input_path = Path(input_path) if not input_path.exists(): raise FileNotFoundError(f"Entrada não encontrada: {input_path}") if dtype: os.environ["TORCH_DTYPE"] = dtype out = output_dir or (self.app_home / "outputs" / "seedvr_refine") out.mkdir(parents=True, exist_ok=True) script = self._find_refine_script() cmd = [sys.executable, str(script), "--mode", "refine", "--ckpt_dir", str(self.ckpt_dir), "--input", str(input_path), "--output_dir", str(out), "--strength", str(strength), "--t_consistency", str(t_consistency), "--denoise", str(denoise), "--upscale", str(upscale)] if fps_out is not None: cmd += ["--fps", str(fps_out)] if tile: cmd += ["--tile", str(tile)] if extra: cmd += extra # Pré-checagem do shim self.ensure_apex(enable_shim=True) self._preflight_imports() env = self._gpu_env() print("CMD:", " ".join(map(str, cmd))) print("PYTHONPATH:", env.get("PYTHONPATH")) subprocess.check_call(cmd, env=env, cwd=str(self.app_home)) return str(out)