# aduc_framework/managers/wan_manager.py
# WanManager v0.1.4 (final)

import os
import platform
import shutil
import subprocess
import tempfile
import random
from typing import List, Any, Optional, Tuple

import numpy as np
import torch
from PIL import Image

# SDPA / FlashAttention context (PyTorch 2.1+ / 2.0 fallback)
try:
    from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch 2.1+
    _SDPA_NEW = True
except Exception:
    from torch.backends.cuda import sdp_kernel as _legacy_sdp  # PyTorch 2.0
    _SDPA_NEW = False

from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
class WanManager:
    """
    Wan I2V manager:
    - Startup banner with PyTorch/CUDA/SDPA/GPU checks
    - Two 3D transformers (high/low noise), bf16, device_map='auto', per-GPU max_memory
    - Lightning LoRA fused into the transformers and then unloaded
    - SDPA preferring FlashAttention, with fallback to the efficient/math kernels
    - Three anchors: image (t=0, weight 1), handle (UI frame k aligned to 1 mod 4), last (final t)
    - Fallback when the pipeline does not support the custom args (handle/anchor)

    See the usage sketch at the bottom of this module.
    """

    MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
    TRANSFORMER_ID = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"

    # Dimensions / frames
    MAX_DIMENSION = 832
    MIN_DIMENSION = 480
    DIMENSION_MULTIPLE = 16
    SQUARE_SIZE = 480

    FIXED_FPS = 16
    MIN_FRAMES_MODEL = 8
    MAX_FRAMES_MODEL = 81

    # Default negative prompt (kept in the original Chinese shipped with the Wan release)
    default_negative_prompt = (
        "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,"
        "JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,"
        "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
    )
    def __init__(self) -> None:
        # Environment verification banner
        self._print_env_banner()

        print("Loading models into memory. This may take a few minutes...")

        # Automatic sharding with valid max_memory keys (integer GPU indices and "cpu")
        n_gpus = torch.cuda.device_count()
        max_memory = {i: "45GiB" for i in range(n_gpus)}  # adjust to available VRAM
        max_memory["cpu"] = "120GiB"

        transformer = WanTransformer3DModel.from_pretrained(
            self.TRANSFORMER_ID,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            max_memory=max_memory,
        )
        transformer_2 = WanTransformer3DModel.from_pretrained(
            self.TRANSFORMER_ID,
            subfolder="transformer_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            max_memory=max_memory,
        )

        self.pipe = WanImageToVideoPipeline.from_pretrained(
            self.MODEL_ID,
            transformer=transformer,
            transformer_2=transformer_2,
            torch_dtype=torch.bfloat16,
        )

        # FlowMatch Euler scheduler (shift=32.0)
        self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, shift=32.0
        )

        # Lightning LoRA (fused, then unloaded)
        print("Applying 8-step Lightning LoRA...")
        try:
            self.pipe.load_lora_weights(
                "Kijai/WanVideo_comfy",
                weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
                adapter_name="lightx2v",
            )
            self.pipe.load_lora_weights(
                "Kijai/WanVideo_comfy",
                weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
                adapter_name="lightx2v_2",
                load_into_transformer_2=True,
            )
            self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
            print("Fusing LoRA weights into the main model...")
            self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
            self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
            self.pipe.unload_lora_weights()
            print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
        except Exception as e:
            print(f"[WanManager] WARNING: failed to fuse the Lightning LoRA (continuing without fusion): {e}")

        print("All models loaded. Service is ready.")
    # ---------- Banner / checks ----------
    def _print_env_banner(self) -> None:
        def _safe_get(fn, default="n/a"):
            try:
                return fn()
            except Exception:
                return default

        torch_ver = getattr(torch, "__version__", "unknown")
        cuda_rt = getattr(torch.version, "cuda", "unknown")
        cudnn_ver = _safe_get(lambda: torch.backends.cudnn.version())
        cuda_ok = torch.cuda.is_available()
        n_gpu = torch.cuda.device_count() if cuda_ok else 0

        devs, total_vram, caps = [], [], []
        if cuda_ok:
            for i in range(n_gpu):
                props = torch.cuda.get_device_properties(i)
                devs.append(f"cuda:{i} {props.name}")
                total_vram.append(f"{props.total_memory / 1024**3:.1f}GiB")
                caps.append(f"{props.major}.{props.minor}")

        # BF16 / TF32
        try:
            bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
        except Exception:
            bf16_supported = False
        if cuda_ok and caps:
            major = int(caps[0].split(".")[0])
            bf16_supported = major >= 8
        tf32_allowed = getattr(torch.backends.cuda.matmul, "allow_tf32", False)

        # SDPA API
        try:
            from torch.nn.attention import sdpa_kernel as _probe1  # noqa
            sdpa_api = "torch.nn.attention (2.1+)"
        except Exception:
            try:
                from torch.backends.cuda import sdp_kernel as _probe2  # noqa
                sdpa_api = "torch.backends.cuda (2.0)"
            except Exception:
                sdpa_api = "unavailable"

        # xFormers
        try:
            import xformers  # noqa
            xformers_ok = True
        except Exception:
            xformers_ok = False

        alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "unset")
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "unset")
        python_ver = platform.python_version()

        nvcc = shutil.which("nvcc")
        nvcc_ver = "n/a"
        if nvcc:
            try:
                nvcc_ver = subprocess.check_output([nvcc, "--version"], text=True).strip().splitlines()[-1]
            except Exception:
                nvcc_ver = "n/a"

        banner_lines = [
            "================== WAN MANAGER • ENV ==================",
            f"Python              : {python_ver}",
            f"PyTorch             : {torch_ver}",
            f"CUDA (torch)        : {cuda_rt}",
            f"cuDNN               : {cudnn_ver}",
            f"CUDA available      : {cuda_ok}",
            f"GPU count           : {n_gpu}",
            f"GPUs                : {', '.join(devs) if devs else 'n/a'}",
            f"GPU VRAM            : {', '.join(total_vram) if total_vram else 'n/a'}",
            f"Compute Capability  : {', '.join(caps) if caps else 'n/a'}",
            f"BF16 supported      : {bf16_supported}",
            f"TF32 allowed        : {tf32_allowed}",
            f"SDPA API            : {sdpa_api}",
            f"xFormers available  : {xformers_ok}",
            f"CUDA_VISIBLE_DEVICES: {visible}",
            f"PYTORCH_CUDA_ALLOC_CONF: {alloc_conf}",
            f"nvcc                : {nvcc_ver}",
            "=======================================================",
        ]
        print("\n".join(banner_lines))
    # ---------- Image utilities ----------
    def _round_multiple(self, x: int, multiple: int) -> int:
        return int(round(x / multiple) * multiple)

    def process_image_for_video(self, image: Image.Image) -> Image.Image:
        w, h = image.size
        if w == h:
            return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)

        ar = w / h
        nw, nh = w, h

        # Upper clamp: scale down so the larger side fits MAX_DIMENSION
        if nw > self.MAX_DIMENSION or nh > self.MAX_DIMENSION:
            s = (self.MAX_DIMENSION / nw) if ar > 1 else (self.MAX_DIMENSION / nh)
            nw, nh = nw * s, nh * s

        # Lower clamp: scale up so the smaller side reaches MIN_DIMENSION
        if nw < self.MIN_DIMENSION or nh < self.MIN_DIMENSION:
            s = (self.MIN_DIMENSION / nh) if ar > 1 else (self.MIN_DIMENSION / nw)
            nw, nh = nw * s, nh * s

        fw = self._round_multiple(int(nw), self.DIMENSION_MULTIPLE)
        fh = self._round_multiple(int(nh), self.DIMENSION_MULTIPLE)

        # Consistent final minimums
        fw = max(fw, self.MIN_DIMENSION if ar < 1 else self.SQUARE_SIZE)
        fh = max(fh, self.MIN_DIMENSION if ar > 1 else self.SQUARE_SIZE)

        return image.resize((fw, fh), Image.Resampling.LANCZOS)

    def resize_and_crop_to_match(self, target: Image.Image, ref: Image.Image) -> Image.Image:
        # Scale `target` to cover `ref`, then center-crop to the exact reference size.
        rw, rh = ref.size
        tw, th = target.size
        s = max(rw / tw, rh / th)
        nw, nh = int(tw * s), int(th * s)
        resized = target.resize((nw, nh), Image.Resampling.LANCZOS)
        left, top = (nw - rw) // 2, (nh - rh) // 2
        return resized.crop((left, top, left + rw, top + rh))
    # ---------- API ----------
    def generate_video_from_conditions(
        self,
        images_condition_items: List[List[Any]],  # [[image (PIL.Image), frame (int|str), weight (float)], ...]
        prompt: str,
        negative_prompt: Optional[str],
        duration_seconds: float,
        steps: int,
        guidance_scale: float,
        guidance_scale_2: float,
        seed: int,
        randomize_seed: bool,
        output_type: str = "np",
    ) -> Tuple[str, int]:
        # Validation
        if not images_condition_items or len(images_condition_items) < 2:
            raise ValueError("Provide at least two items (start and end).")
        items = images_condition_items
        start_image = items[0][0]
        end_image = items[-1][0]
        if start_image is None or end_image is None:
            raise ValueError("The start and end images must not be empty.")
        if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
            raise TypeError("Condition images must be PIL.Image instances.")

        # Optional handle (middle anchor)
        handle_image = items[1][0] if len(items) >= 3 else None

        # Weights
        handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
        end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0

        # Preprocess and align H x W
        processed_start = self.process_image_for_video(start_image)
        processed_end = self.resize_and_crop_to_match(end_image, processed_start)
        processed_handle = (
            self.resize_and_crop_to_match(handle_image, processed_start) if handle_image is not None else None
        )
        H, W = processed_start.height, processed_start.width

        # Frame count (the pipeline adjusts to 4n+1 internally; here we only clamp)
        num_frames = int(round(duration_seconds * self.FIXED_FPS))
        num_frames = int(np.clip(num_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))

        # Seed
        current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
        generator = torch.Generator().manual_seed(current_seed)
        # Base call arguments
        call_kwargs = dict(
            image=processed_start,
            last_image=processed_end,
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else self.default_negative_prompt,
            height=H,
            width=W,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=generator,
            output_type=output_type,
        )

        # Handle frame from the UI, passed through as the latent index
        # (the UI value is assumed to already be aligned to 1 mod 4).
        corrected_handle_index = int(items[1][1]) if processed_handle is not None else 0

        # Final kwargs (with/without handle)
        if processed_handle is not None:
            kwargs = dict(
                **call_kwargs,
                handle_image=processed_handle,
                handle_weight=float(handle_weight),
                handle_latent_index=corrected_handle_index,
                anchor_weight_last=float(end_weight),
            )
        else:
            kwargs = dict(
                **call_kwargs,
                anchor_weight_last=float(end_weight),
            )

        # Execution under an SDPA context preferring FlashAttention, with fallback to the
        # efficient/math kernels; if the pipeline does not accept the custom handle/anchor
        # kwargs, retry with the plain i2v arguments (see the class docstring).
        if _SDPA_NEW:
            sdpa_ctx = sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
        else:
            sdpa_ctx = _legacy_sdp(enable_flash=True, enable_mem_efficient=True, enable_math=True)

        with sdpa_ctx:
            try:
                result = self.pipe(**kwargs)
            except TypeError:
                print("[WanManager] WARNING: pipeline does not accept handle/anchor kwargs; retrying without them.")
                result = self.pipe(**call_kwargs)
        frames = result.frames[0]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            video_path = tmp.name
        export_to_video(frames, video_path, fps=self.FIXED_FPS)
        return video_path, current_seed
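

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only), referenced from the class docstring.
# The image paths, prompt, duration, and seed below are placeholder assumptions,
# not part of the framework; adapt them to your own inputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = WanManager()
    start = Image.open("start.png").convert("RGB")  # hypothetical start frame
    end = Image.open("end.png").convert("RGB")      # hypothetical end frame
    video_path, used_seed = manager.generate_video_from_conditions(
        images_condition_items=[[start, 0, 1.0], [end, 80, 1.0]],  # [image, frame, weight]
        prompt="a smooth camera pan across the scene",
        negative_prompt=None,  # falls back to default_negative_prompt
        duration_seconds=4.0,
        steps=8,               # matches the 8-step Lightning setup fused in __init__
        guidance_scale=1.0,
        guidance_scale_2=1.0,
        seed=42,
        randomize_seed=False,
    )
    print(f"Video written to {video_path} (seed={used_seed})")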