Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import shutil | |
| import subprocess | |
| import textwrap | |
| from pathlib import Path | |
| from typing import Optional, List, Dict | |
| from huggingface_hub import snapshot_download | |
# ---------------------------------------------------------------------------
# Paths and model identifiers (APP_HOME, MODELS_DIR and REPO_ID can be
# overridden through environment variables).
# ---------------------------------------------------------------------------
APP_HOME: Path = Path(os.environ.get("APP_HOME", "/app"))

# Location of the cloned SeedVR repository; adjust if it was cloned elsewhere.
SEED_REPO_DIR: Path = APP_HOME.joinpath("SeedVR")
CONFIGS_DIR: Path = APP_HOME.joinpath("configs_3b")
CKPT_DIR: Path = APP_HOME.joinpath("ckpt", "SeedVR2-3B")

MODELS_DIR: Path = Path(os.environ.get("MODELS_DIR", "/app/models"))
REPO_ID: str = os.environ.get("REPO_ID_SEED", "ByteDance-Seed/SeedVR2-3B")

# Checkpoint files that must be present in CKPT_DIR (per app-13 and seedvr-1.sh).
REQUIRED_FILES: List[str] = [
    "seedvr2_ema_3b.pth",  # main model
    "ema_vae.pth",  # VAE
    "pos_emb.pt",  # positive embeddings
    "neg_emb.pt",  # negative embeddings
]
| def _env_bool(name: str, default: bool = True) -> bool: | |
| v = os.environ.get(name) | |
| return default if v is None else v.strip().lower() in ("1", "true", "yes", "on") | |
class SeedVRRefineService:
    """Prepare and launch the SeedVR2-3B refine step as a child process.

    Responsibilities:

    * install a small pure-PyTorch ``apex.normalization`` shim under
      ``<app_home>/shims`` so SeedVR can import it without NVIDIA Apex;
    * download the checkpoint files in ``REQUIRED_FILES`` from the Hugging Face
      Hub into ``ckpt_dir``;
    * build a GPU/NCCL-oriented environment and run the refine script with
      ``subprocess.check_call``.
    """

    # Defaults for the refine subprocess environment. Each entry is applied with
    # ``setdefault``, so a value already present in the parent environment wins.
    _GPU_ENV_DEFAULTS: Dict[str, str] = {
        # CUDA / allocator
        "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
        "CUDA_DEVICE_MAX_CONNECTIONS": "32",
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_MODULE_LOADING": "LAZY",
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512,garbage_collection_threshold:0.8",
        # NCCL
        "NCCL_DEBUG": "INFO",
        # NOTE(review): recent PyTorch reads TORCH_NCCL_ASYNC_ERROR_HANDLING
        # instead; kept unchanged here -- confirm for the installed version.
        "NCCL_ASYNC_ERROR_HANDLING": "1",
        "NCCL_P2P_DISABLE": "0",
        "NCCL_IB_DISABLE": "1",  # adjust for your interconnect topology
        "NCCL_MIN_NCHANNELS": "8",
        "NCCL_NTHREADS": "256",
        # Attention backends (SDPA / FlashAttention / xFormers). These names are
        # presumably read by the SeedVR scripts rather than PyTorch -- confirm.
        "ENABLE_FLASH_SDP": "1",
        "ENABLE_MEMORY_EFFICIENT_SDP": "1",
        "ENABLE_MATH_SDP": "0",
        "FLASH_ATTENTION_DISABLE": "0",
        "XFORMERS_FORCE_DISABLE": "1",
        "TORCH_DTYPE": "bfloat16",
    }

    # Content of apex/__init__.py. Having it makes the shim a *regular* package,
    # so it wins by sys.path order instead of being shadowed by any installed
    # regular package named ``apex`` (PEP 420 namespace portions lose to those).
    _APEX_INIT_SHIM = "# Pure-PyTorch stand-in for NVIDIA Apex; only apex.normalization is provided.\n"

    # Content of apex/normalization.py.
    _APEX_NORMALIZATION_SHIM = textwrap.dedent(
        """\
        # Pure-PyTorch replacements for apex.normalization (no real Apex needed).
        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        # torch >= 2.4 ships a native rms_norm; older versions use the fallback.
        _native_rms_norm = getattr(F, "rms_norm", None)


        class FusedRMSNorm(nn.Module):
            # Mirrors apex's FusedRMSNorm: the scale lives directly on this
            # module as `weight`, so state_dict keys match checkpoints saved
            # with real Apex (no extra `.norm.` level in the key).
            def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True,
                         memory_efficient=False):
                super().__init__()
                if isinstance(normalized_shape, int):
                    normalized_shape = (normalized_shape,)
                self.normalized_shape = tuple(normalized_shape)
                self.eps = eps
                self.elementwise_affine = elementwise_affine
                if elementwise_affine:
                    self.weight = nn.Parameter(torch.ones(self.normalized_shape))
                else:
                    self.register_parameter("weight", None)

            def forward(self, x):
                if _native_rms_norm is not None:
                    return _native_rms_norm(x, self.normalized_shape, self.weight, self.eps)
                # Fallback: compute statistics in at least float32 so bf16/fp16
                # inputs do not lose precision, then cast back.
                dims = tuple(range(-len(self.normalized_shape), 0))
                xf = x.to(torch.promote_types(x.dtype, torch.float32))
                y = xf * torch.rsqrt(xf.pow(2).mean(dim=dims, keepdim=True) + self.eps)
                y = y.type_as(x)
                if self.weight is not None:
                    y = y * self.weight
                return y


        class FusedLayerNorm(nn.LayerNorm):
            # apex's extra `memory_efficient` flag is accepted and ignored.
            def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True,
                         memory_efficient=False, **kwargs):
                super().__init__(normalized_shape, eps=eps,
                                 elementwise_affine=elementwise_affine, **kwargs)
        """
    )

    def __init__(self) -> None:
        """Bind the service to the module-level path configuration."""
        self.app_home = APP_HOME
        self.repo_dir = SEED_REPO_DIR
        self.configs = CONFIGS_DIR
        self.ckpt_dir = CKPT_DIR

    # ---------------- Helpers ----------------
    @staticmethod
    def _prepend_path(entry: str, current: Optional[str]) -> str:
        """Return a PYTHONPATH-style string with *entry* first and not repeated.

        Existing occurrences of *entry* in *current* are dropped so repeated
        calls do not grow the variable; an unset/empty *current* yields just
        *entry* (no dangling separator).
        """
        rest = [p for p in current.split(os.pathsep) if p != entry] if current else []
        return os.pathsep.join([entry, *rest])

    def _shims_root(self) -> Path:
        """Directory prepended to the import path; it holds the ``apex`` shim."""
        return self.app_home / "shims"

    # ---------------- Apex shim (no real Apex) ----------------
    def ensure_apex(self, enable_shim: bool = True) -> None:
        """Install the ``apex.normalization`` shim and expose it to imports.

        Writes ``<app_home>/shims/apex/{__init__,normalization}.py`` (rewritten
        only when the content differs), prepends ``<app_home>/shims`` to
        ``sys.path`` for this process and to ``os.environ["PYTHONPATH"]`` for
        child processes, without duplicating an existing entry.

        Args:
            enable_shim: when False, nothing is done.
        """
        if not enable_shim:
            return
        import importlib

        shims_root = self._shims_root()
        package_dir = shims_root / "apex"
        package_dir.mkdir(parents=True, exist_ok=True)
        sources = {
            "__init__.py": self._APEX_INIT_SHIM,
            "normalization.py": self._APEX_NORMALIZATION_SHIM,
        }
        for filename, source in sources.items():
            target = package_dir / filename
            # Skip the write when unchanged so concurrent importers are not disturbed.
            if not target.is_file() or target.read_text(encoding="utf-8") != source:
                target.write_text(source, encoding="utf-8")
        # Freshly created source files may be invisible to cached path finders.
        importlib.invalidate_caches()

        root = str(shims_root)
        if root not in sys.path:
            sys.path.insert(0, root)
        os.environ["PYTHONPATH"] = self._prepend_path(root, os.environ.get("PYTHONPATH"))

    def _preflight_imports(self) -> None:
        """Import ``apex.normalization`` in this process to fail fast.

        Raises:
            RuntimeError: if the import fails (original error chained as cause).
        """
        import importlib

        try:
            importlib.import_module("apex.normalization")
        except Exception as e:
            raise RuntimeError(f"apex shim não resolvido: {e}") from e
        print("apex shim OK (apex.normalization resolvido)")

    # ---------------- Models ----------------
    def _missing_ckpt_files(self) -> List[str]:
        """Return the names in REQUIRED_FILES that do not exist in ``ckpt_dir``."""
        return [name for name in REQUIRED_FILES if not (self.ckpt_dir / name).exists()]

    def ensure_model(self, max_workers: int = 48, token: Optional[str] = None) -> str:
        """Download the SeedVR2-3B checkpoint files into ``ckpt_dir`` if any are missing.

        Only REQUIRED_FILES plus ``*.md``/``*.txt`` are fetched from REPO_ID.
        The token, if given, is passed straight to ``snapshot_download``
        (no global ``login`` whose failure could be silently ignored).

        Args:
            max_workers: parallel download workers for ``snapshot_download``.
            token: optional Hugging Face access token.

        Returns:
            A status message naming the checkpoint directory.

        Raises:
            FileNotFoundError: if required files are still missing after the download.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        if self._missing_ckpt_files():
            # NOTE(review): newer huggingface_hub versions deprecate
            # local_dir_use_symlinks/resume_download; kept for older installs.
            snapshot_download(
                repo_id=REPO_ID,
                local_dir=str(self.ckpt_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
                max_workers=max_workers,
                allow_patterns=[*REQUIRED_FILES, "*.md", "*.txt"],
                token=token,
            )
            still_missing = self._missing_ckpt_files()
            if still_missing:
                raise FileNotFoundError(
                    f"Arquivos ausentes em {self.ckpt_dir} após download de {REPO_ID}: "
                    + ", ".join(still_missing)
                )
        return f"SeedVR ckpts prontos em {self.ckpt_dir}"

    # ---------------- GPU environment ----------------
    def _gpu_env(self) -> Dict[str, str]:
        """Build the environment for the refine subprocess.

        Starts from a copy of ``os.environ``, fills in ``_GPU_ENV_DEFAULTS`` for
        keys that are not already set, and puts the shim root first on
        ``PYTHONPATH`` (without duplicating it).
        """
        env = os.environ.copy()
        for key, value in self._GPU_ENV_DEFAULTS.items():
            env.setdefault(key, value)
        env["PYTHONPATH"] = self._prepend_path(str(self._shims_root()), env.get("PYTHONPATH"))
        return env

    # ---------------- Refine execution ----------------
    def _find_refine_script(self) -> Path:
        """Return the first existing refine entry point inside ``repo_dir``.

        Per app-13 the runner is VideoDiffusionInfer; a ``refine_cli.py``
        entry point is expected here. If the repository ships a dedicated
        SR/refine script, adjust the candidate list.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        candidates = [
            self.repo_dir / "projects" / "video_diffusion_sr" / "refine_cli.py",
            self.repo_dir / "inference_refine.py",
            self.repo_dir / "inference.py",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate
        raise FileNotFoundError("Script de refine do SeedVR não encontrado (ajuste _find_refine_script).")

    def refine(
        self,
        input_path: Path,
        output_dir: Optional[Path] = None,
        upscale: float = 1.0,
        strength: float = 0.35,
        denoise: float = 0.1,
        t_consistency: float = 0.7,
        fps_out: Optional[int] = None,
        tile: Optional[int] = None,
        dtype: Optional[str] = None,
        extra: Optional[List[str]] = None
    ) -> str:
        """Run the SeedVR refine script on *input_path*; return the output directory.

        Every path handed to the child is made absolute first: the child runs
        with ``cwd=self.app_home`` and would otherwise resolve relative paths
        against that directory instead of the caller's working directory.

        Args:
            input_path: file to refine; must exist.
            output_dir: destination directory (str or Path, created if needed);
                defaults to ``<app_home>/outputs/seedvr_refine``.
            upscale, strength, denoise, t_consistency: forwarded as
                ``--upscale``, ``--strength``, ``--denoise``, ``--t_consistency``.
            fps_out: forwarded as ``--fps`` when not None.
            tile: forwarded as ``--tile`` when truthy (None and 0 are omitted).
            dtype: if given, exported as ``TORCH_DTYPE`` to the child only; the
                parent's ``os.environ`` is not modified.
            extra: additional CLI arguments appended verbatim (converted to str).

        NOTE(review): whether the selected script accepts all these flags
        (e.g. ``--mode refine``) depends on which candidate
        ``_find_refine_script`` returned -- confirm against that script.

        Returns:
            The absolute output directory as a string.

        Raises:
            FileNotFoundError: if *input_path* or the refine script is missing.
            RuntimeError: if the apex shim cannot be imported.
            subprocess.CalledProcessError: if the script exits non-zero.
        """
        input_path = Path(input_path).absolute()
        if not input_path.exists():
            raise FileNotFoundError(f"Entrada não encontrada: {input_path}")

        out = Path(output_dir) if output_dir else self.app_home / "outputs" / "seedvr_refine"
        out = out.absolute()
        out.mkdir(parents=True, exist_ok=True)

        script = self._find_refine_script().absolute()
        cmd = [
            sys.executable, str(script),
            "--mode", "refine",
            "--ckpt_dir", str(self.ckpt_dir.absolute()),
            "--input", str(input_path),
            "--output_dir", str(out),
            "--strength", str(strength),
            "--t_consistency", str(t_consistency),
            "--denoise", str(denoise),
            "--upscale", str(upscale),
        ]
        if fps_out is not None:
            cmd += ["--fps", str(fps_out)]
        if tile:
            cmd += ["--tile", str(tile)]
        if extra:
            cmd.extend(str(arg) for arg in extra)

        # Shim pre-check: install it and make sure it imports before launching.
        self.ensure_apex(enable_shim=True)
        self._preflight_imports()

        env = self._gpu_env()
        if dtype:
            # Scope the dtype to this run instead of leaking it into os.environ.
            env["TORCH_DTYPE"] = dtype
        print("CMD:", " ".join(cmd))
        print("PYTHONPATH:", env.get("PYTHONPATH"))
        subprocess.check_call(cmd, env=env, cwd=str(self.app_home))
        return str(out)