Update aduc_framework/managers/wan_manager.py

aduc_framework/managers/wan_manager.py  (+162 -116)  CHANGED
Removed and surrounding lines, by hunk (removed lines marked with "-"):

@@ -1,5 +1,5 @@
  # aduc_framework/managers/wan_manager.py
- # WanManager

  import os
  import platform

@@ -13,15 +13,12 @@ import numpy as np
  import torch
  from PIL import Image

- #
- torch.backends.cuda.matmul.allow_tf32 = True
-
- # SDPA / FlashAttention context
  try:
-     from torch.nn.attention import sdpa_kernel, SDPBackend
      _SDPA_NEW = True
  except Exception:
-     from torch.backends.cuda import sdp_kernel as _legacy_sdp
      _SDPA_NEW = False

  from diffusers import FlowMatchEulerDiscreteScheduler
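Note that the removed lines above were the only place where TF32 matmul was switched on globally; after this commit the environment banner merely reports the current allow_tf32 value. If the old behaviour is still wanted, it can be restored with the standard PyTorch flags (an optional sketch, not part of the commit):

import torch

torch.backends.cuda.matmul.allow_tf32 = True   # TF32 matmul on Ampere and newer
torch.backends.cudnn.allow_tf32 = True         # TF32 for cuDNN convolutions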
@@ -29,88 +26,108 @@ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
  from diffusers.utils.export_utils import export_to_video

- from aduc_framework.utils.callbacks import DenoiseStepLogger
-

  class WanManager:
      """
-
-
-
-
-
-
-
-     - **Optimized performance:** uses a fused Lightning LoRA for fast generation and
-       leverages SDPA (Scaled Dot Product Attention) with a smart fallback chain
-       (Flash -> Efficient -> Math) for maximum speed.
-     - **Robust parameter validation:** implements business rules to validate and correct
-       the total frame count (`4n+1`) and the control-frame position (`8n+1` with
-       safety buffers), ensuring stability and predictable results.
-     - **Visual debugging:** integrates a callback system to capture the denoising
-       process, producing a debug video and a grid image of every step.
      """

      MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
      TRANSFORMER_ID = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"

      MAX_DIMENSION = 832
      MIN_DIMENSION = 480
      DIMENSION_MULTIPLE = 16
      SQUARE_SIZE = 480
      FIXED_FPS = 16
      MIN_FRAMES_MODEL = 8
      MAX_FRAMES_MODEL = 81

-     # Default negative prompt in English
      default_negative_prompt = (
-         "
-         "
-         "
-         "poorly drawn face, malformed limbs, fused fingers, messy background, three legs, "
-         "too many people, walking backwards."
      )

      def __init__(self) -> None:
          self._print_env_banner()
          print("Loading models into memory. This may take a few minutes...")

          n_gpus = torch.cuda.device_count()
-         max_memory = {i: "
          max_memory["cpu"] = "120GiB"

          transformer = WanTransformer3DModel.from_pretrained(
-             self.TRANSFORMER_ID,
-
          )
          transformer_2 = WanTransformer3DModel.from_pretrained(
-             self.TRANSFORMER_ID,
-
          )
          self.pipe = WanImageToVideoPipeline.from_pretrained(
-             self.MODEL_ID,
          )
-         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.pipe.scheduler.config, shift=32.0)

          print("Applying 8-step Lightning LoRA...")
          try:
-             self.pipe.load_lora_weights(
-
              self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
              print("Fusing LoRA weights into the main model...")
              self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
              self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
              self.pipe.unload_lora_weights()
-             print("Lightning LoRA successfully fused.")
          except Exception as e:
-             print(f"[WanManager] AVISO: Falha ao fundir LoRA Lightning: {e}")

          print("All models loaded. Service is ready.")

      def _print_env_banner(self) -> None:
          def _safe_get(fn, default="n/a"):
-             try:
-

          torch_ver = getattr(torch, "__version__", "unknown")
          cuda_rt = getattr(torch.version, "cuda", "unknown")
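The removed import above pulled DenoiseStepLogger from aduc_framework.utils.callbacks, and the removed docstring describes it as the hook that turns each denoising step into a debug video and a step grid. That class is not part of this diff; the following is only an illustrative stand-in in the diffusers callback_on_step_end style (names and storage are assumptions, not the framework's real implementation):

class StepLoggerSketch:
    # Collects the latents tensor after every denoising step.
    def __init__(self):
        self.intermediate_latents = []

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        latents = callback_kwargs.get("latents")
        if latents is not None:
            self.intermediate_latents.append(latents.detach().cpu())
        return callback_kwargs

The removed code wired such an object in via callback_on_step_end=denoise_callback and callback_on_step_end_tensor_inputs=["latents"].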
@@ -124,17 +141,33 @@ class WanManager:
              devs.append(f"cuda:{i} {props.name}")
              total_vram.append(f"{props.total_memory/1024**3:.1f}GiB")
              caps.append(f"{props.major}.{props.minor}")
-
-
-         except: bf16_supported = False
-
-         tf32_allowed = torch.backends.cuda.matmul.allow_tf32
-         sdpa_api = "torch.nn.attention (2.1+)" if _SDPA_NEW else "torch.backends.cuda (2.0)" if not _SDPA_NEW and hasattr(torch.backends.cuda, 'sdp_kernel') else "unavailable"
-
          try:
-
              xformers_ok = True
-         except
              xformers_ok = False

          alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "unset")
@@ -143,43 +176,62 @@
          nvcc = shutil.which("nvcc")
          nvcc_ver = "n/a"
          if nvcc:
-             try:
-

          banner_lines = [
              "================== WAN MANAGER • ENV ==================",
-             f"Python : {python_ver}",
-             f"
-             f"CUDA
              f"GPUs : {', '.join(devs) if devs else 'n/a'}",
              f"GPU VRAM : {', '.join(total_vram) if total_vram else 'n/a'}",
              f"Compute Capability : {', '.join(caps) if caps else 'n/a'}",
-             f"BF16 supported : {bf16_supported}",
-             f"
-             f"
              f"nvcc : {nvcc_ver}",
              "=======================================================",
          ]
          print("\n".join(banner_lines))

      def _round_multiple(self, x: int, multiple: int) -> int:
          return int(round(x / multiple) * multiple)

      def process_image_for_video(self, image: Image.Image) -> Image.Image:
          w, h = image.size
-         if w == h:
          ar = w / h
          nw, nh = w, h
          if nw > self.MAX_DIMENSION or nh > self.MAX_DIMENSION:
              s = (self.MAX_DIMENSION / nw) if ar > 1 else (self.MAX_DIMENSION / nh)
              nw, nh = nw * s, nh * s
          if nw < self.MIN_DIMENSION or nh < self.MIN_DIMENSION:
              s = (self.MIN_DIMENSION / nh) if ar > 1 else (self.MIN_DIMENSION / nw)
              nw, nh = nw * s, nh * s
          fw = self._round_multiple(int(nw), self.DIMENSION_MULTIPLE)
          fh = self._round_multiple(int(nh), self.DIMENSION_MULTIPLE)
          fw = max(fw, self.MIN_DIMENSION if ar < 1 else self.SQUARE_SIZE)
          fh = max(fh, self.MIN_DIMENSION if ar > 1 else self.SQUARE_SIZE)
          return image.resize((fw, fh), Image.Resampling.LANCZOS)

      def resize_and_crop_to_match(self, target: Image.Image, ref: Image.Image) -> Image.Image:
@@ -191,9 +243,10 @@
          left, top = (nw - rw) // 2, (nh - rh) // 2
          return resized.crop((left, top, left + rw, top + rh))

      def generate_video_from_conditions(
          self,
-         images_condition_items: List[List[Any]],
          prompt: str,
          negative_prompt: Optional[str],
          duration_seconds: float,
@@ -203,7 +256,8 @@
          seed: int,
          randomize_seed: bool,
          output_type: str = "np",
-     ) -> Tuple[str, int
          if not images_condition_items or len(images_condition_items) < 2:
              raise ValueError("Forneça ao menos dois itens (início e fim).")
@@ -212,82 +266,74 @@
          end_image = items[-1][0]
          if start_image is None or end_image is None:
              raise ValueError("As imagens inicial e final não podem ser vazias.")

          handle_image = items[1][0] if len(items) >= 3 else None
          handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
          end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0

          processed_start = self.process_image_for_video(start_image)
          processed_end = self.resize_and_crop_to_match(end_image, processed_start)
          processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image else None

          H, W = processed_start.height, processed_start.width
-
-         # 1. Compute and validate the total number of frames
-         initial_frames = int(round(duration_seconds * self.FIXED_FPS))
-         clamped_frames = int(np.clip(initial_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
-         sf_t = getattr(self.pipe, "vae_scale_factor_temporal", 4)
-         num_frames = ((clamped_frames - 1) // sf_t * sf_t) + 1  # enforce the 4n+1 format
-
-         print(f"[WanManager] INFO: Duração {duration_seconds}s => {initial_frames} frames. "
-               f"Após clamp e alinhamento 4n+1, o total de frames final é {num_frames}.")

          current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
          generator = torch.Generator().manual_seed(current_seed)

-
-         callback_kwargs = {"callback_on_step_end": denoise_callback, "callback_on_step_end_tensor_inputs": ["latents"]}
-
          call_kwargs = dict(
-             image=processed_start,
-
-
          )

-         #
-         corrected_handle_index =
-
-
-
-             block_index = round(handle_frame_ui / 8)
-             aligned_frame = (block_index * 8) + 1
-
-             min_safe_frame = 9  # 8-frame buffer at the start (1*8 + 1)
-             max_safe_frame = num_frames - 9  # 8-frame buffer at the end
-
-             corrected_handle_index = max(min_safe_frame, min(aligned_frame, max_safe_frame))
-
-             print(f"[WanManager] INFO: Handle Frame UI {handle_frame_ui} alinhado para {aligned_frame} e validado para {corrected_handle_index} (limites seguros: {min_safe_frame}-{max_safe_frame}).")
-
-         base_kwargs = {**call_kwargs, "anchor_weight_last": float(end_weight)}
          if processed_handle is not None:
-
-
-
-
-
-
-
          result = None

-         result = self.pipe(**base_kwargs)
          frames = result.frames[0]

-
          with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
              video_path = tmp.name
          export_to_video(frames, video_path, fps=self.FIXED_FPS)

-
-         if denoise_callback.intermediate_frames:
-             with tempfile.NamedTemporaryFile(suffix="_denoise_process.mp4", delete=False) as tmp:
-                 debug_video_path = tmp.name
-             denoise_callback.save_as_video(debug_video_path, fps=max(1, steps // 2))
-
-             grid_pil = denoise_callback.create_steps_grid()
-             if grid_pil:
-                 with tempfile.NamedTemporaryFile(suffix="_steps_grid.png", delete=False) as tmp:
-                     grid_image_path = tmp.name
-                 grid_pil.save(grid_image_path)
-
-         return video_path, current_seed, debug_video_path, grid_image_path
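For reference, the two frame rules removed above (the new version keeps only the clamp and leaves temporal alignment to the pipeline) work out as follows for a hypothetical 5-second request at 16 fps; the numbers simply replay the removed arithmetic with the default vae_scale_factor_temporal of 4:

duration_seconds, FIXED_FPS = 5.0, 16
initial_frames = int(round(duration_seconds * FIXED_FPS))   # 80
clamped = max(8, min(81, initial_frames))                    # 80
num_frames = ((clamped - 1) // 4) * 4 + 1                    # 77 = 4*19 + 1

handle_frame_ui = 30
aligned_frame = round(handle_frame_ui / 8) * 8 + 1           # 33 (nearest 8n+1)
corrected = max(9, min(aligned_frame, num_frames - 9))       # 33, safe range 9..68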
Updated file (added lines marked with "+"):

  # aduc_framework/managers/wan_manager.py
+ # WanManager v0.1.4 (final)

  import os
  import platform

  import torch
  from PIL import Image

+ # SDPA / FlashAttention context (PyTorch 2.1+ / 2.0 fallback)
  try:
+     from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch 2.1+
      _SDPA_NEW = True
  except Exception:
+     from torch.backends.cuda import sdp_kernel as _legacy_sdp  # PyTorch 2.0
      _SDPA_NEW = False

  from diffusers import FlowMatchEulerDiscreteScheduler

  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
  from diffusers.utils.export_utils import export_to_video


  class WanManager:
      """
+     Wan i2v Manager:
+     - Startup banner with PyTorch/CUDA/SDPA/GPU checks
+     - Two 3D Transformers (high/low noise), bf16, device_map='auto', max_memory per GPU
+     - Lightning LoRA fused and then unloaded
+     - SDPA preferring FlashAttention, with fallback (efficient/math)
+     - 3 anchor points: image (t=0, weight 1), handle (UI frame k aligned to 1 (mod 4)), last (final t)
+     - Fallback if the pipeline does not support the custom args (handle/anchor)
      """

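The docstring above promises FlashAttention-first SDPA with an efficient/math fallback, but the attention context itself is outside this excerpt. A minimal sketch of how the _SDPA_NEW flag could pick the context manager (the exact call site in this file is not shown; the list-of-backends form of sdpa_kernel and the legacy sdp_kernel flags are standard PyTorch APIs):

from contextlib import nullcontext

def attention_ctx():
    # Prefer Flash, then memory-efficient, then math SDPA.
    if _SDPA_NEW:
        return sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                            SDPBackend.EFFICIENT_ATTENTION,
                            SDPBackend.MATH])
    if "_legacy_sdp" in globals():
        return _legacy_sdp(enable_flash=True, enable_mem_efficient=True, enable_math=True)
    return nullcontext()

# with attention_ctx():
#     result = pipe(**kwargs)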
      MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
      TRANSFORMER_ID = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"

+     # Dimensions/frames
      MAX_DIMENSION = 832
      MIN_DIMENSION = 480
      DIMENSION_MULTIPLE = 16
      SQUARE_SIZE = 480
+
      FIXED_FPS = 16
      MIN_FRAMES_MODEL = 8
      MAX_FRAMES_MODEL = 81

      default_negative_prompt = (
+         "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,"
+         "JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,"
+         "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
      )

      def __init__(self) -> None:
+         # Verification banner
          self._print_env_banner()
+
          print("Loading models into memory. This may take a few minutes...")

+         # Automatic sharding with valid keys (integers and "cpu")
          n_gpus = torch.cuda.device_count()
+         max_memory = {i: "45GiB" for i in range(n_gpus)}  # adjust to available VRAM
          max_memory["cpu"] = "120GiB"

          transformer = WanTransformer3DModel.from_pretrained(
+             self.TRANSFORMER_ID,
+             subfolder="transformer",
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             max_memory=max_memory,
          )
          transformer_2 = WanTransformer3DModel.from_pretrained(
+             self.TRANSFORMER_ID,
+             subfolder="transformer_2",
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             max_memory=max_memory,
          )
+
          self.pipe = WanImageToVideoPipeline.from_pretrained(
+             self.MODEL_ID,
+             transformer=transformer,
+             transformer_2=transformer_2,
+             torch_dtype=torch.bfloat16,
+         )
+
+         # FlowMatch Euler scheduler (shift=32.0)
+         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config, shift=32.0
          )

+         # Lightning LoRA (fusion)
          print("Applying 8-step Lightning LoRA...")
          try:
+             self.pipe.load_lora_weights(
+                 "Kijai/WanVideo_comfy",
+                 weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+                 adapter_name="lightx2v",
+             )
+             self.pipe.load_lora_weights(
+                 "Kijai/WanVideo_comfy",
+                 weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+                 adapter_name="lightx2v_2",
+                 load_into_transformer_2=True,
+             )
              self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
+
              print("Fusing LoRA weights into the main model...")
              self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
              self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
              self.pipe.unload_lora_weights()
+             print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
          except Exception as e:
+             print(f"[WanManager] AVISO: Falha ao fundir LoRA Lightning (seguirá sem fusão): {e}")

          print("All models loaded. Service is ready.")
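The two transformers above are sharded with device_map="auto" under a per-device max_memory budget. A small sketch of the same configuration in isolation (the "45GiB"/"120GiB" figures mirror the code above; inspecting the placement via hf_device_map is an assumption about the accelerate-backed loader, not something this file does):

import torch

n_gpus = torch.cuda.device_count()
max_memory = {i: "45GiB" for i in range(n_gpus)}  # integer keys = CUDA device indices
max_memory["cpu"] = "120GiB"                      # spill-over budget for CPU RAM
print(max_memory)  # e.g. {0: '45GiB', 1: '45GiB', 'cpu': '120GiB'} on a 2-GPU host

# After from_pretrained(..., device_map="auto", max_memory=max_memory), the chosen
# placement can usually be read back with getattr(model, "hf_device_map", None).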

+     # ---------- Banner/Checks ----------
      def _print_env_banner(self) -> None:
          def _safe_get(fn, default="n/a"):
+             try:
+                 return fn()
+             except Exception:
+                 return default

          torch_ver = getattr(torch, "__version__", "unknown")
          cuda_rt = getattr(torch.version, "cuda", "unknown")

              devs.append(f"cuda:{i} {props.name}")
              total_vram.append(f"{props.total_memory/1024**3:.1f}GiB")
              caps.append(f"{props.major}.{props.minor}")
+
+         # BF16/TF32
          try:
+             bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
+         except Exception:
+             bf16_supported = False
+         if cuda_ok and caps:
+             major = int(caps[0].split(".")[0])
+             bf16_supported = major >= 8
+         tf32_allowed = getattr(torch.backends.cuda.matmul, "allow_tf32", False)
+
+         # SDPA API
+         try:
+             from torch.nn.attention import sdpa_kernel as _probe1  # noqa
+             sdpa_api = "torch.nn.attention (2.1+)"
+         except Exception:
+             try:
+                 from torch.backends.cuda import sdp_kernel as _probe2  # noqa
+                 sdpa_api = "torch.backends.cuda (2.0)"
+             except Exception:
+                 sdpa_api = "unavailable"
+
+         # xFormers
+         try:
+             import xformers  # noqa
              xformers_ok = True
+         except Exception:
              xformers_ok = False

          alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "unset")

          nvcc = shutil.which("nvcc")
          nvcc_ver = "n/a"
          if nvcc:
+             try:
+                 nvcc_ver = subprocess.check_output([nvcc, "--version"], text=True).strip().splitlines()[-1]
+             except Exception:
+                 nvcc_ver = "n/a"

          banner_lines = [
              "================== WAN MANAGER • ENV ==================",
+             f"Python : {python_ver}",
+             f"PyTorch : {torch_ver}",
+             f"CUDA (torch) : {cuda_rt}",
+             f"cuDNN : {cudnn_ver}",
+             f"CUDA available : {cuda_ok}",
+             f"GPU count : {n_gpu}",
              f"GPUs : {', '.join(devs) if devs else 'n/a'}",
              f"GPU VRAM : {', '.join(total_vram) if total_vram else 'n/a'}",
              f"Compute Capability : {', '.join(caps) if caps else 'n/a'}",
+             f"BF16 supported : {bf16_supported}",
+             f"TF32 allowed : {tf32_allowed}",
+             f"SDPA API : {sdpa_api}",
+             f"xFormers available : {xformers_ok}",
+             f"CUDA_VISIBLE_DEVICES: {visible}",
+             f"PYTORCH_CUDA_ALLOC_CONF: {alloc_conf}",
              f"nvcc : {nvcc_ver}",
              "=======================================================",
          ]
          print("\n".join(banner_lines))
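The BF16 line in the banner combines torch.cuda.is_bf16_supported() with a compute-capability override (major >= 8, i.e. Ampere or newer). The same check in isolation:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"cc={major}.{minor}, bf16={torch.cuda.is_bf16_supported() or major >= 8}")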

+     # ---------- image utils ----------
      def _round_multiple(self, x: int, multiple: int) -> int:
          return int(round(x / multiple) * multiple)

      def process_image_for_video(self, image: Image.Image) -> Image.Image:
          w, h = image.size
+         if w == h:
+             return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)
+
          ar = w / h
          nw, nh = w, h
+
+         # upper clamp
          if nw > self.MAX_DIMENSION or nh > self.MAX_DIMENSION:
              s = (self.MAX_DIMENSION / nw) if ar > 1 else (self.MAX_DIMENSION / nh)
              nw, nh = nw * s, nh * s
+
+         # lower clamp
          if nw < self.MIN_DIMENSION or nh < self.MIN_DIMENSION:
              s = (self.MIN_DIMENSION / nh) if ar > 1 else (self.MIN_DIMENSION / nw)
              nw, nh = nw * s, nh * s
+
          fw = self._round_multiple(int(nw), self.DIMENSION_MULTIPLE)
          fh = self._round_multiple(int(nh), self.DIMENSION_MULTIPLE)
+
+         # consistent final minimums
          fw = max(fw, self.MIN_DIMENSION if ar < 1 else self.SQUARE_SIZE)
          fh = max(fh, self.MIN_DIMENSION if ar > 1 else self.SQUARE_SIZE)
+
          return image.resize((fw, fh), Image.Resampling.LANCZOS)

      def resize_and_crop_to_match(self, target: Image.Image, ref: Image.Image) -> Image.Image:

          left, top = (nw - rw) // 2, (nh - rh) // 2
          return resized.crop((left, top, left + rw, top + rh))
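Worked through for a hypothetical 1920x1080 input, the resizing above proceeds as follows (values obtained by mirroring the method's arithmetic, not from running the class):

upper clamp (ar > 1):      scale 832/1920  -> 832.0 x 468.0
lower clamp (468 < 480):   scale 480/468   -> ~853.3 x 480.0
round to multiple of 16:                   -> 848 x 480
orientation minimums:      max(848, 480) x max(480, 480) -> 848 x 480

So a 16:9 source lands at 848x480 and then serves as the reference that resize_and_crop_to_match applies to the other anchor images.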

+     # ---------- API ----------
      def generate_video_from_conditions(
          self,
+         images_condition_items: List[List[Any]],  # [[image(Image), frame(int|str), weight(float)], ...]
          prompt: str,
          negative_prompt: Optional[str],
          duration_seconds: float,

          seed: int,
          randomize_seed: bool,
          output_type: str = "np",
+     ) -> Tuple[str, int]:
+         # validation
          if not images_condition_items or len(images_condition_items) < 2:
              raise ValueError("Forneça ao menos dois itens (início e fim).")

          end_image = items[-1][0]
          if start_image is None or end_image is None:
              raise ValueError("As imagens inicial e final não podem ser vazias.")
+         if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
+             raise TypeError("Patches devem ser PIL.Image.")

+         # optional handle
          handle_image = items[1][0] if len(items) >= 3 else None
+
+         # weights
          handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
          end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0

+         # preprocessing and HxW alignment
          processed_start = self.process_image_for_video(start_image)
          processed_end = self.resize_and_crop_to_match(end_image, processed_start)
          processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image else None

          H, W = processed_start.height, processed_start.width

+         # frames (the pipeline aligns to 4n+1 internally; only clamp here)
+         num_frames = int(round(duration_seconds * self.FIXED_FPS))
+         num_frames = int(np.clip(num_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
+
+         # seed
          current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
          generator = torch.Generator().manual_seed(current_seed)

+         # base arguments
          call_kwargs = dict(
+             image=processed_start,
+             last_image=processed_end,
+             prompt=prompt,
+             negative_prompt=negative_prompt if negative_prompt else self.default_negative_prompt,
+             height=H,
+             width=W,
+             num_frames=num_frames,
+             guidance_scale=float(guidance_scale),
+             guidance_scale_2=float(guidance_scale_2),
+             num_inference_steps=int(steps),
+             generator=generator,
+             output_type=output_type,
          )

+         # map the handle's UI frame -> latent index aligned to 1 (mod 4)
+         corrected_handle_index = int(items[1][1])
+
+         # assemble final kwargs (with/without handle)
          if processed_handle is not None:
+             kwargs = dict(
+                 **call_kwargs,
+                 handle_image=processed_handle,
+                 handle_weight=float(handle_weight),
+                 handle_latent_index=corrected_handle_index,
+                 anchor_weight_last=float(end_weight),
+             )
+         else:
+             kwargs = dict(
+                 **call_kwargs,
+                 anchor_weight_last=float(end_weight),
+             )
+
+         # execution with SDPA and backend fallback
          result = None
+
+         result = self.pipe(**kwargs)

          frames = result.frames[0]

          with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
              video_path = tmp.name
          export_to_video(frames, video_path, fps=self.FIXED_FPS)

+         return video_path, current_seed
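The class docstring lists a fallback for pipelines that do not accept the custom handle/anchor arguments, but the guarded call is not part of this excerpt; result = self.pipe(**kwargs) is shown unprotected. A minimal sketch of such a guard, assuming a stock WanImageToVideoPipeline rejects unknown keyword arguments with a TypeError:

try:
    result = self.pipe(**kwargs)
except TypeError:
    # Stock pipeline without handle/anchor support: drop the custom keys and retry.
    supported = {k: v for k, v in kwargs.items()
                 if k not in ("handle_image", "handle_weight",
                              "handle_latent_index", "anchor_weight_last")}
    result = self.pipe(**supported)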
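For callers, a hypothetical invocation of the new API (file names, the middle-anchor frame and the weights are made up; the item layout follows the [[image, frame, weight], ...] comment on the signature, and the guidance_scale/guidance_scale_2/steps parameter names are inferred from the function body since that part of the signature is elided in this diff):

from PIL import Image

manager = WanManager()
video_path, used_seed = manager.generate_video_from_conditions(
    images_condition_items=[
        [Image.open("start.png"), 0, 1.0],    # first anchor (t=0)
        [Image.open("handle.png"), 33, 0.8],  # optional middle anchor
        [Image.open("end.png"), 80, 1.0],     # last anchor
    ],
    prompt="a slow dolly shot across a rainy street",
    negative_prompt=None,          # falls back to default_negative_prompt
    duration_seconds=5.0,
    guidance_scale=1.0,
    guidance_scale_2=1.0,
    steps=8,
    seed=42,
    randomize_seed=False,
)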