File size: 8,768 Bytes
015dd20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations
import os
import sys
import shutil
import subprocess
import textwrap
from pathlib import Path
from typing import Optional, List, Dict

from huggingface_hub import snapshot_download

# Base directory of the application; overridable through the APP_HOME env var.
APP_HOME = Path(os.getenv("APP_HOME", "/app"))
# Checkout of the SeedVR repository (adjust if the repo was cloned to another path).
SEED_REPO_DIR = APP_HOME / "SeedVR"
CONFIGS_DIR = APP_HOME / "configs_3b"
CKPT_DIR = APP_HOME.joinpath("ckpt", "SeedVR2-3B")
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/app/models"))
REPO_ID = os.getenv("REPO_ID_SEED", "ByteDance-Seed/SeedVR2-3B")

# Essential checkpoint files (as listed in app-13 and seedvr-1.sh).
REQUIRED_FILES = [
    "seedvr2_ema_3b.pth",  # main model
    "ema_vae.pth",  # VAE
    "pos_emb.pt",  # positive embeddings
    "neg_emb.pt",  # negative embeddings
]

def _env_bool(name: str, default: bool = True) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Returns *default* only when the variable is not set at all. When it is set,
    the result is True for "1", "true", "yes" or "on" (case-insensitive,
    surrounding whitespace ignored) and False for any other value, including
    the empty string.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    truthy = {"1", "true", "yes", "on"}
    return raw.strip().lower() in truthy

class SeedVRRefineService:
    """Prepare the SeedVR2-3B runtime and run its refine script in a subprocess.

    What this class does:

    * ``ensure_apex`` installs a pure-PyTorch stand-in for ``apex.normalization``
      so code that imports Apex runs without the real package;
    * ``ensure_model`` downloads the required checkpoints from the Hugging Face
      Hub when any of them is missing;
    * ``refine`` builds a GPU/NCCL-oriented environment and launches the refine
      script found in the SeedVR checkout.
    """

    # Root of the shim tree; it is put first on sys.path and on PYTHONPATH.
    SHIMS_ROOT = Path("/app/shims")

    # Source written to <SHIMS_ROOT>/apex/normalization.py.
    _APEX_NORMALIZATION_SRC = textwrap.dedent("""
        import torch
        import torch.nn as nn


        class _RMSNormFallback(nn.Module):
            # RMSNorm for PyTorch builds that lack nn.RMSNorm. The scale is a direct
            # parameter named ``weight`` so state_dict keys match apex.normalization.
            def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
                super().__init__()
                if isinstance(normalized_shape, int):
                    normalized_shape = (normalized_shape,)
                self.normalized_shape = tuple(normalized_shape)
                self.eps = eps
                self.elementwise_affine = elementwise_affine
                if elementwise_affine:
                    self.weight = nn.Parameter(torch.ones(self.normalized_shape))
                else:
                    self.register_parameter("weight", None)

            def forward(self, x):
                dim = tuple(range(-len(self.normalized_shape), 0))
                variance = x.pow(2).mean(dim=dim, keepdim=True)
                x = x * torch.rsqrt(variance + self.eps)
                if self.weight is not None:
                    x = x * self.weight
                return x


        # FusedRMSNorm *subclasses* the implementation instead of wrapping it, so the
        # scale lives at ``<module>.weight`` as in apex's FusedRMSNorm. A wrapper
        # (``self.norm = ...``) would insert a ``.norm.`` level into every key and
        # break loading of existing checkpoints.
        if hasattr(nn, "RMSNorm"):
            class FusedRMSNorm(nn.RMSNorm):
                def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
                    super().__init__(
                        normalized_shape, eps=eps, elementwise_affine=elementwise_affine
                    )
        else:
            class FusedRMSNorm(_RMSNormFallback):
                pass


        class FusedLayerNorm(nn.LayerNorm):
            pass
    """)

    def __init__(self) -> None:
        self.app_home = APP_HOME
        self.repo_dir = SEED_REPO_DIR
        self.configs = CONFIGS_DIR
        self.ckpt_dir = CKPT_DIR

    # ---------------- Path helpers ----------------
    @staticmethod
    def _with_path_entry(entry: str, pythonpath: Optional[str]) -> str:
        """Return *pythonpath* with *entry* first, dropping duplicates and empty items.

        Args:
            entry: directory to place at the front.
            pythonpath: existing ``os.pathsep``-separated value, or None/"" if unset.
        """
        rest = [p for p in (pythonpath or "").split(os.pathsep) if p and p != entry]
        return os.pathsep.join([entry, *rest])

    @staticmethod
    def _write_if_changed(path: Path, text: str) -> None:
        """Write *text* to *path* unless the file already contains exactly *text*."""
        try:
            if path.read_text(encoding="utf-8") == text:
                return
        except (OSError, UnicodeDecodeError):
            pass  # missing or unreadable: (re)write it below
        path.write_text(text, encoding="utf-8")

    # ---------------- Apex shim (no real Apex) ----------------
    def ensure_apex(self, enable_shim: bool = True) -> None:
        """Install a pure-PyTorch ``apex.normalization`` shim and expose it.

        Creates ``<SHIMS_ROOT>/apex/__init__.py`` and ``normalization.py``
        (rewriting them only when their content differs), prepends
        ``SHIMS_ROOT`` to ``sys.path`` for this process if absent, and places it
        first on ``os.environ["PYTHONPATH"]`` for child processes without
        duplicating it.

        Args:
            enable_shim: when False, do nothing.
        """
        if not enable_shim:
            return
        import importlib

        apex_dir = self.SHIMS_ROOT / "apex"
        apex_dir.mkdir(parents=True, exist_ok=True)
        # A regular package (with __init__.py) is returned as soon as its directory
        # is scanned; a namespace package would lose to any real ``apex`` found
        # later on sys.path (PEP 420).
        self._write_if_changed(apex_dir / "__init__.py", "# Apex shim (pure PyTorch).\n")
        self._write_if_changed(apex_dir / "normalization.py", self._APEX_NORMALIZATION_SRC)
        # Path finders cache directory listings; files created at runtime may be
        # missed without this.
        importlib.invalidate_caches()

        shims_root = str(self.SHIMS_ROOT)
        if shims_root not in sys.path:
            sys.path.insert(0, shims_root)
        os.environ["PYTHONPATH"] = self._with_path_entry(
            shims_root, os.environ.get("PYTHONPATH")
        )

    def _preflight_imports(self) -> None:
        """Import ``apex.normalization`` to confirm it resolves in this process.

        Raises:
            RuntimeError: if the import fails; the original error is chained.
        """
        import importlib

        try:
            importlib.import_module("apex.normalization")
        except Exception as e:
            raise RuntimeError(f"apex shim não resolvido: {e}") from e
        print("apex shim OK (apex.normalization resolvido)")

    # ---------------- Models ----------------
    def _missing_ckpt_files(self) -> List[str]:
        """Return the names in REQUIRED_FILES that do not exist in ``self.ckpt_dir``."""
        return [f for f in REQUIRED_FILES if not (self.ckpt_dir / f).exists()]

    def ensure_model(self, max_workers: int = 48, token: Optional[str] = None) -> str:
        """Download the SeedVR2-3B checkpoints into ``self.ckpt_dir`` if any is missing.

        Args:
            max_workers: parallel download workers passed to ``snapshot_download``.
            token: optional Hugging Face token; used via ``login`` (best effort).

        Returns:
            A status message naming the checkpoint directory.

        Raises:
            FileNotFoundError: if a required file is still missing after download.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        if not self._missing_ckpt_files():
            return f"SeedVR ckpts prontos em {self.ckpt_dir}"

        if token:
            try:
                from huggingface_hub import login
                login(token=token)
            except Exception as e:
                # Best effort: the download may still succeed without login,
                # so warn instead of failing here.
                print(f"Aviso: login no Hugging Face falhou ({e}); seguindo sem login.")

        snapshot_download(
            repo_id=REPO_ID,
            local_dir=str(self.ckpt_dir),
            # NOTE(review): both kwargs below are deprecated in recent
            # huggingface_hub releases; kept for older versions — confirm.
            local_dir_use_symlinks=False,
            resume_download=True,
            max_workers=max_workers,
            allow_patterns=[*REQUIRED_FILES, "*.md", "*.txt"],
        )

        missing = self._missing_ckpt_files()
        if missing:
            raise FileNotFoundError(
                f"Arquivos ausentes em {self.ckpt_dir} após o download: {missing}"
            )
        return f"SeedVR ckpts prontos em {self.ckpt_dir}"

    # ---------------- GPU environment ----------------
    def _gpu_env(self) -> Dict[str, str]:
        """Return a copy of ``os.environ`` with GPU/NCCL/attention defaults filled in.

        Variables already present in the environment are never overridden. The
        shim root is placed first on ``PYTHONPATH`` without duplicating it.
        """
        env = os.environ.copy()
        defaults = {
            "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
            "CUDA_DEVICE_MAX_CONNECTIONS": "32",
            "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
            "CUDA_MODULE_LOADING": "LAZY",
            "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512,garbage_collection_threshold:0.8",
            # NCCL
            "NCCL_DEBUG": "INFO",
            "NCCL_ASYNC_ERROR_HANDLING": "1",
            "NCCL_P2P_DISABLE": "0",
            "NCCL_IB_DISABLE": "1",  # adjust for your interconnect topology
            "NCCL_MIN_NCHANNELS": "8",
            "NCCL_NTHREADS": "256",
            # SDPA / FlashAttention
            "ENABLE_FLASH_SDP": "1",
            "ENABLE_MEMORY_EFFICIENT_SDP": "1",
            "ENABLE_MATH_SDP": "0",
            "FLASH_ATTENTION_DISABLE": "0",
            "XFORMERS_FORCE_DISABLE": "1",
            "TORCH_DTYPE": "bfloat16",
        }
        for key, value in defaults.items():
            env.setdefault(key, value)
        env["PYTHONPATH"] = self._with_path_entry(str(self.SHIMS_ROOT), env.get("PYTHONPATH"))
        return env

    # ---------------- Refine execution ----------------
    def _find_refine_script(self) -> Path:
        """Return the first existing refine entry point in the SeedVR checkout.

        The candidates are tried in order; edit the list if the repository ships a
        different sr/refine script.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        candidates = [
            self.repo_dir / "projects" / "video_diffusion_sr" / "refine_cli.py",
            self.repo_dir / "inference_refine.py",
            self.repo_dir / "inference.py",
        ]
        for p in candidates:
            if p.exists():
                return p
        raise FileNotFoundError("Script de refine do SeedVR não encontrado (ajuste _find_refine_script).")

    def _build_refine_cmd(
        self,
        script: Path,
        input_path: Path,
        out: Path,
        *,
        upscale: float,
        strength: float,
        denoise: float,
        t_consistency: float,
        fps_out: Optional[int],
        tile: Optional[int],
        extra: Optional[List[str]],
    ) -> List[str]:
        """Assemble the argv for the refine script (all items are ``str``)."""
        cmd = [
            sys.executable, str(script),
            "--mode", "refine",
            # Absolute, because the subprocess runs with cwd=self.app_home.
            "--ckpt_dir", str(self.ckpt_dir.resolve()),
            "--input", str(input_path),
            "--output_dir", str(out),
            "--strength", str(strength),
            "--t_consistency", str(t_consistency),
            "--denoise", str(denoise),
            "--upscale", str(upscale),
        ]
        if fps_out is not None:
            cmd += ["--fps", str(fps_out)]
        if tile:
            cmd += ["--tile", str(tile)]
        if extra:
            # subprocess rejects non-string args (e.g. ints), so coerce them.
            cmd += [str(arg) for arg in extra]
        return cmd

    def refine(
        self,
        input_path: Path,
        output_dir: Optional[Path] = None,
        upscale: float = 1.0,
        strength: float = 0.35,
        denoise: float = 0.1,
        t_consistency: float = 0.7,
        fps_out: Optional[int] = None,
        tile: Optional[int] = None,
        dtype: Optional[str] = None,
        extra: Optional[List[str]] = None
    ) -> str:
        """Run the SeedVR refine script on *input_path* and return the output dir.

        Args:
            input_path: file to refine (``str`` or ``Path``).
            output_dir: destination directory (``str`` or ``Path``); defaults to
                ``<app_home>/outputs/seedvr_refine``. Created if missing.
            upscale, strength, denoise, t_consistency: forwarded as CLI flags.
            fps_out: forwarded as ``--fps`` when not None.
            tile: forwarded as ``--tile`` when truthy.
            dtype: if given, sets ``TORCH_DTYPE`` for the child process only.
            extra: additional raw CLI arguments appended to the command.

        Returns:
            The absolute output directory as a string.

        Raises:
            FileNotFoundError: if the input or the refine script is missing.
            RuntimeError: if the apex shim cannot be imported.
            subprocess.CalledProcessError: if the refine script exits non-zero.
        """
        # Resolve to absolute paths: the subprocess runs with cwd=self.app_home, so
        # relative paths would otherwise point somewhere other than what the
        # existence check below saw.
        input_path = Path(input_path).resolve()
        if not input_path.exists():
            raise FileNotFoundError(f"Entrada não encontrada: {input_path}")

        out = Path(output_dir) if output_dir else self.app_home / "outputs" / "seedvr_refine"
        out = out.resolve()
        out.mkdir(parents=True, exist_ok=True)

        script = self._find_refine_script().resolve()
        cmd = self._build_refine_cmd(
            script, input_path, out,
            upscale=upscale, strength=strength, denoise=denoise,
            t_consistency=t_consistency, fps_out=fps_out, tile=tile, extra=extra,
        )

        # Shim pre-check.
        self.ensure_apex(enable_shim=True)
        self._preflight_imports()
        env = self._gpu_env()
        if dtype:
            # Only the child sees the override; os.environ stays untouched so one
            # call's dtype does not leak into later calls.
            env["TORCH_DTYPE"] = dtype

        print("CMD:", " ".join(cmd))
        print("PYTHONPATH:", env.get("PYTHONPATH"))
        subprocess.check_call(cmd, env=env, cwd=str(self.app_home))
        return str(out)