# carlex3321's picture
# Create seedvr.py
# 015dd20 verified
# raw
# history blame
# 8.77 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import sys
import shutil
import subprocess
import textwrap
from pathlib import Path
from typing import Optional, List, Dict
from huggingface_hub import snapshot_download
# Application root; the SeedVR repo, configs and checkpoints are resolved under it.
APP_HOME = Path(os.getenv("APP_HOME", "/app"))
# Cloned SeedVR repository (adjust if the repo was cloned to a different path).
SEED_REPO_DIR = APP_HOME / "SeedVR"
CONFIGS_DIR = APP_HOME / "configs_3b"
CKPT_DIR = APP_HOME.joinpath("ckpt", "SeedVR2-3B")
# Read from its own variable; its default does not follow APP_HOME.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/app/models"))
# Hugging Face Hub repo the checkpoints are downloaded from.
REPO_ID = os.getenv("REPO_ID_SEED", "ByteDance-Seed/SeedVR2-3B")
# Essential checkpoint files (per app-13 and seedvr-1.sh).
REQUIRED_FILES = [
    "seedvr2_ema_3b.pth",  # main model
    "ema_vae.pth",  # VAE
    "pos_emb.pt",  # positive embeddings
    "neg_emb.pt",  # negative embeddings
]
def _env_bool(name: str, default: bool = True) -> bool:
v = os.environ.get(name)
return default if v is None else v.strip().lower() in ("1", "true", "yes", "on")
class SeedVRRefineService:
    """Prepare and launch SeedVR2-3B "refine" runs in a subprocess.

    Responsibilities:
      * ``ensure_apex`` writes a minimal ``apex.normalization`` shim so code that
        imports ``FusedRMSNorm``/``FusedLayerNorm`` works without NVIDIA Apex;
      * ``ensure_model`` downloads the files listed in ``REQUIRED_FILES`` from
        the Hugging Face Hub into ``ckpt_dir``;
      * ``refine`` builds a GPU/NCCL-oriented environment and runs the repo's
        refine script via ``subprocess.check_call``.
    """

    # Directory holding the generated ``apex`` shim package. It is prepended to
    # ``sys.path`` (this process) and to ``PYTHONPATH`` (child processes).
    shims_root: Path = Path("/app/shims")

    # Source written to ``<shims_root>/apex/normalization.py``.
    _APEX_NORMALIZATION_SRC = textwrap.dedent('''\
        """Minimal stand-in for ``apex.normalization`` (no NVIDIA Apex required)."""
        import torch
        import torch.nn as nn


        class _RMSNormFallback(nn.Module):
            """RMSNorm for torch builds that lack ``nn.RMSNorm``."""

            def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
                super().__init__()
                if isinstance(normalized_shape, int):
                    normalized_shape = (normalized_shape,)
                self.normalized_shape = tuple(normalized_shape)
                self.eps = eps
                self.elementwise_affine = elementwise_affine
                if elementwise_affine:
                    self.weight = nn.Parameter(torch.ones(self.normalized_shape))
                else:
                    self.register_parameter("weight", None)

            def forward(self, x):
                dim = tuple(range(-len(self.normalized_shape), 0))
                variance = x.pow(2).mean(dim=dim, keepdim=True)
                x = x * torch.rsqrt(variance + self.eps)
                if self.weight is not None:
                    x = x * self.weight
                return x


        # Subclass (rather than wrap) the norm so its parameter is registered as
        # ``<name>.weight`` -- the layout apex's FusedRMSNorm uses -- and
        # checkpoints saved with apex load without key remapping.
        _RMSNormBase = getattr(nn, "RMSNorm", _RMSNormFallback)


        class FusedRMSNorm(_RMSNormBase):
            def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
                super().__init__(
                    normalized_shape, eps=eps, elementwise_affine=elementwise_affine
                )


        class FusedLayerNorm(nn.LayerNorm):
            pass
        ''')

    def __init__(self) -> None:
        self.app_home = APP_HOME
        self.repo_dir = SEED_REPO_DIR
        self.configs = CONFIGS_DIR
        self.ckpt_dir = CKPT_DIR

    # ---------------- Helpers ----------------
    @staticmethod
    def _prepend_path(entry: str, current: Optional[str]) -> str:
        """Return PYTHONPATH-style string *current* with *entry* moved to the front.

        Any existing occurrence of *entry* is removed first, so repeated calls
        never duplicate it. An unset or empty *current* yields just *entry*
        (no dangling separator).
        """
        if not current:
            return entry
        rest = [part for part in current.split(os.pathsep) if part != entry]
        return os.pathsep.join([entry, *rest])

    # ---------------- Apex shim (no real Apex) ----------------
    def ensure_apex(self, enable_shim: bool = True) -> None:
        """Write the ``apex`` shim package and make it importable.

        No-op when *enable_shim* is False. Otherwise (re)writes
        ``<shims_root>/apex/normalization.py``, creates an empty
        ``apex/__init__.py`` if missing, and puts ``shims_root`` first on
        ``sys.path`` and in ``os.environ["PYTHONPATH"]``. Safe to call repeatedly.
        """
        if not enable_shim:
            return
        pkg_dir = Path(self.shims_root) / "apex"
        pkg_dir.mkdir(parents=True, exist_ok=True)
        # A real __init__.py makes the shim a *regular* package. Without it the
        # directory is only a namespace-package portion, and any real ``apex``
        # regular package found later on sys.path would take precedence.
        init_py = pkg_dir / "__init__.py"
        if not init_py.exists():
            init_py.write_text("")
        (pkg_dir / "normalization.py").write_text(self._APEX_NORMALIZATION_SRC)

        root = str(self.shims_root)
        if root not in sys.path:
            sys.path.insert(0, root)
        os.environ["PYTHONPATH"] = self._prepend_path(root, os.environ.get("PYTHONPATH"))

    def _preflight_imports(self) -> None:
        """Import ``apex.normalization`` in this process to verify the shim resolves.

        Raises:
            RuntimeError: if the import fails (original error chained as cause).
        """
        import importlib

        # The shim file may have been created after the import system cached
        # directory listings; invalidate so the new module is noticed.
        importlib.invalidate_caches()
        try:
            importlib.import_module("apex.normalization")
        except Exception as e:
            raise RuntimeError(f"apex shim não resolvido: {e}") from e
        print("apex shim OK (apex.normalization resolvido)")

    # ---------------- Models ----------------
    def _missing_ckpt_files(self) -> List[str]:
        """Return the entries of REQUIRED_FILES not present in ``ckpt_dir``."""
        return [name for name in REQUIRED_FILES if not (self.ckpt_dir / name).exists()]

    def ensure_model(self, max_workers: int = 48, token: Optional[str] = None) -> str:
        """Download the required SeedVR checkpoint files unless already present.

        Args:
            max_workers: parallel download workers passed to ``snapshot_download``.
            token: optional Hugging Face token. It is passed to
                ``snapshot_download`` and also used for a best-effort ``login``.

        Returns:
            A status message naming ``ckpt_dir``.

        Raises:
            FileNotFoundError: if required files are still missing after download.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        if not self._missing_ckpt_files():
            return f"SeedVR ckpts prontos em {self.ckpt_dir}"

        if token:
            # Best effort: keeps the original global-login side effect, but a
            # failure no longer hides silently; the download below still gets
            # the token explicitly.
            try:
                from huggingface_hub import login

                login(token=token)
            except Exception as exc:
                print(f"aviso: login no Hugging Face falhou ({exc}); usando token direto no download",
                      file=sys.stderr)

        snapshot_download(
            repo_id=REPO_ID,
            local_dir=str(self.ckpt_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
            max_workers=max_workers,
            allow_patterns=[*REQUIRED_FILES, "*.md", "*.txt"],
            token=token,
        )
        missing = self._missing_ckpt_files()
        if missing:
            raise FileNotFoundError(
                f"Arquivos ausentes em {self.ckpt_dir} após download de {REPO_ID}: {missing}"
            )
        return f"SeedVR ckpts prontos em {self.ckpt_dir}"

    # ---------------- GPU environment ----------------
    def _gpu_env(self) -> Dict[str, str]:
        """Return a copy of ``os.environ`` for the refine subprocess.

        The defaults below are filled in only where the variable is not
        already set (existing values win). ``shims_root`` is placed first on
        PYTHONPATH exactly once.
        """
        env = os.environ.copy()
        defaults = {
            # CUDA / allocator
            "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
            "CUDA_DEVICE_MAX_CONNECTIONS": "32",
            "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
            "CUDA_MODULE_LOADING": "LAZY",
            "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512,garbage_collection_threshold:0.8",
            # NCCL
            "NCCL_DEBUG": "INFO",
            "NCCL_ASYNC_ERROR_HANDLING": "1",
            "NCCL_P2P_DISABLE": "0",
            "NCCL_IB_DISABLE": "1",  # adjust for your interconnect topology
            "NCCL_MIN_NCHANNELS": "8",
            "NCCL_NTHREADS": "256",
            # SDPA / FlashAttention toggles. NOTE(review): these names are
            # presumably read by the SeedVR code, not by PyTorch itself -- confirm.
            "ENABLE_FLASH_SDP": "1",
            "ENABLE_MEMORY_EFFICIENT_SDP": "1",
            "ENABLE_MATH_SDP": "0",
            "FLASH_ATTENTION_DISABLE": "0",
            "XFORMERS_FORCE_DISABLE": "1",
            "TORCH_DTYPE": "bfloat16",
        }
        for key, value in defaults.items():
            env.setdefault(key, value)
        env["PYTHONPATH"] = self._prepend_path(str(self.shims_root), env.get("PYTHONPATH"))
        return env

    # ---------------- Refine execution ----------------
    def _find_refine_script(self) -> Path:
        """Return the first existing refine entry point under ``repo_dir``.

        Raises:
            FileNotFoundError: if none of the candidates exist.
        """
        # Per app-13 the runner is VideoDiffusionInfer; here a "refine_cli.py"
        # entry point is expected. If the repo ships a dedicated sr/refine
        # script, adjust the list below.
        candidates = [
            self.repo_dir / "projects" / "video_diffusion_sr" / "refine_cli.py",
            self.repo_dir / "inference_refine.py",
            self.repo_dir / "inference.py",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate
        raise FileNotFoundError("Script de refine do SeedVR não encontrado (ajuste _find_refine_script).")

    def refine(
        self,
        input_path: Path,
        output_dir: Optional[Path] = None,
        upscale: float = 1.0,
        strength: float = 0.35,
        denoise: float = 0.1,
        t_consistency: float = 0.7,
        fps_out: Optional[int] = None,
        tile: Optional[int] = None,
        dtype: Optional[str] = None,
        extra: Optional[List[str]] = None,
    ) -> str:
        """Run the SeedVR refine script on *input_path* in a subprocess.

        The numeric options are forwarded verbatim as CLI flags
        (``--upscale``, ``--strength``, ``--denoise``, ``--t_consistency``,
        ``--fps``, ``--tile``). NOTE(review): their semantics are defined by
        the refine script, which is not visible here.

        Args:
            input_path: input file (str or Path); must exist.
            output_dir: output directory (str or Path); defaults to
                ``<app_home>/outputs/seedvr_refine``. Created if missing.
            fps_out: passed as ``--fps`` when not None.
            tile: passed as ``--tile`` when truthy (0/None omit the flag).
            dtype: sets ``TORCH_DTYPE`` for this subprocess only.
            extra: additional raw CLI arguments appended at the end.

        Returns:
            The output directory as a string.

        Raises:
            FileNotFoundError: if the input or the refine script is missing.
            RuntimeError: if the apex shim cannot be imported.
            subprocess.CalledProcessError: if the script exits non-zero.
        """
        input_path = Path(input_path)
        if not input_path.exists():
            raise FileNotFoundError(f"Entrada não encontrada: {input_path}")

        # Accept str as well as Path; a falsy value selects the default location.
        out = Path(output_dir) if output_dir else self.app_home / "outputs" / "seedvr_refine"
        out.mkdir(parents=True, exist_ok=True)

        script = self._find_refine_script()
        cmd = [
            sys.executable, str(script),
            "--mode", "refine",
            "--ckpt_dir", str(self.ckpt_dir),
            "--input", str(input_path),
            "--output_dir", str(out),
            "--strength", str(strength),
            "--t_consistency", str(t_consistency),
            "--denoise", str(denoise),
            "--upscale", str(upscale),
        ]
        if fps_out is not None:
            cmd += ["--fps", str(fps_out)]
        if tile:
            cmd += ["--tile", str(tile)]
        if extra:
            cmd.extend(extra)

        # Shim pre-check in this process before spawning the child.
        self.ensure_apex(enable_shim=True)
        self._preflight_imports()

        env = self._gpu_env()
        # Per-call override: scoped to the child env instead of mutating the
        # global os.environ, so it does not leak into later calls.
        if dtype:
            env["TORCH_DTYPE"] = dtype

        print("CMD:", " ".join(map(str, cmd)))
        print("PYTHONPATH:", env.get("PYTHONPATH"))
        subprocess.check_call(cmd, env=env, cwd=str(self.app_home))
        return str(out)