x2XcarleX2x committed (verified)
Commit 805147f · Parent(s): 91c93ea

Update aduc_framework/managers/wan_manager.py

Files changed (1)
  1. aduc_framework/managers/wan_manager.py +162 -116
aduc_framework/managers/wan_manager.py CHANGED
@@ -1,5 +1,5 @@
  # aduc_framework/managers/wan_manager.py
- # WanManager v1.0.0 (production-ready)

  import os
  import platform
@@ -13,15 +13,12 @@ import numpy as np
  import torch
  from PIL import Image

- # Enable TF32 for performance on Ampere+ GPUs
- torch.backends.cuda.matmul.allow_tf32 = True
-
- # SDPA / FlashAttention context
  try:
-     from torch.nn.attention import sdpa_kernel, SDPBackend
      _SDPA_NEW = True
  except Exception:
-     from torch.backends.cuda import sdp_kernel as _legacy_sdp
      _SDPA_NEW = False

  from diffusers import FlowMatchEulerDiscreteScheduler
@@ -29,88 +26,108 @@ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
  from diffusers.utils.export_utils import export_to_video

- from aduc_framework.utils.callbacks import DenoiseStepLogger
-

  class WanManager:
      """
-     Production manager for the Wan 2.2 Image-to-Video pipeline.
-
-     Main features:
-     - **Environment diagnostics:** prints a detailed startup banner with information on
-       PyTorch, CUDA, GPUs, and optimization support (SDPA, xFormers).
-     - **Memory management:** distributes the model across multiple GPUs in an optimized
-       way, setting VRAM limits to avoid overload.
-     - **Optimized performance:** uses a fused Lightning LoRA for fast generation and
-       leverages SDPA (Scaled Dot Product Attention) with a smart fallback chain
-       (Flash -> Efficient -> Math) for maximum speed.
-     - **Robust parameter validation:** implements business rules to validate and correct
-       the total frame count (`4n+1`) and the control-frame position (`8n+1` with safety
-       buffers), ensuring stability and predictable results.
-     - **Visual debugging:** integrates a callback system that captures the denoising
-       process, producing a debug video and a grid image of every step.
      """

      MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
      TRANSFORMER_ID = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"

      MAX_DIMENSION = 832
      MIN_DIMENSION = 480
      DIMENSION_MULTIPLE = 16
      SQUARE_SIZE = 480
      FIXED_FPS = 16
      MIN_FRAMES_MODEL = 8
      MAX_FRAMES_MODEL = 81

-     # Default negative prompt in English
      default_negative_prompt = (
-         "bright, overexposed, static, blurry details, text, subtitles, watermark, style, "
-         "artwork, painting, still image, gray scale, worst quality, low quality, jpeg artifacts, "
-         "ugly, deformed, disfigured, missing fingers, extra fingers, poorly drawn hands, "
-         "poorly drawn face, malformed limbs, fused fingers, messy background, three legs, "
-         "too many people, walking backwards."
      )

      def __init__(self) -> None:
          self._print_env_banner()
          print("Loading models into memory. This may take a few minutes...")

          n_gpus = torch.cuda.device_count()
-         max_memory = {i: "43GiB" for i in range(n_gpus)}
          max_memory["cpu"] = "120GiB"

          transformer = WanTransformer3DModel.from_pretrained(
-             self.TRANSFORMER_ID, subfolder="transformer", torch_dtype=torch.bfloat16,
-             device_map="auto", max_memory=max_memory
          )
          transformer_2 = WanTransformer3DModel.from_pretrained(
-             self.TRANSFORMER_ID, subfolder="transformer_2", torch_dtype=torch.bfloat16,
-             device_map="auto", max_memory=max_memory
          )
          self.pipe = WanImageToVideoPipeline.from_pretrained(
-             self.MODEL_ID, transformer=transformer, transformer_2=transformer_2, torch_dtype=torch.bfloat16
          )
-         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.pipe.scheduler.config, shift=32.0)

          print("Applying 8-step Lightning LoRA...")
          try:
-             self.pipe.load_lora_weights("Kijai/WanVideo_comfy", weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", adapter_name="lightx2v")
-             self.pipe.load_lora_weights("Kijai/WanVideo_comfy", weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", adapter_name="lightx2v_2", load_into_transformer_2=True)
              self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
              print("Fusing LoRA weights into the main model...")
              self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
              self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
              self.pipe.unload_lora_weights()
-             print("Lightning LoRA successfully fused.")
          except Exception as e:
-             print(f"[WanManager] AVISO: Falha ao fundir LoRA Lightning: {e}")

          print("All models loaded. Service is ready.")

      def _print_env_banner(self) -> None:
          def _safe_get(fn, default="n/a"):
-             try: return fn()
-             except Exception: return default

          torch_ver = getattr(torch, "__version__", "unknown")
          cuda_rt = getattr(torch.version, "cuda", "unknown")
@@ -124,17 +141,33 @@ class WanManager:
              devs.append(f"cuda:{i} {props.name}")
              total_vram.append(f"{props.total_memory/1024**3:.1f}GiB")
              caps.append(f"{props.major}.{props.minor}")
-
-         try: bf16_supported = torch.cuda.is_bf16_supported()
-         except: bf16_supported = False
-
-         tf32_allowed = torch.backends.cuda.matmul.allow_tf32
-         sdpa_api = "torch.nn.attention (2.1+)" if _SDPA_NEW else "torch.backends.cuda (2.0)" if not _SDPA_NEW and hasattr(torch.backends.cuda, 'sdp_kernel') else "unavailable"
-
          try:
-             import xformers
              xformers_ok = True
-         except ImportError:
              xformers_ok = False

          alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "unset")
@@ -143,43 +176,62 @@ class WanManager:
          nvcc = shutil.which("nvcc")
          nvcc_ver = "n/a"
          if nvcc:
-             try: nvcc_ver = subprocess.check_output([nvcc, "--version"], text=True).strip().splitlines()[-1]
-             except Exception: nvcc_ver = "n/a"

          banner_lines = [
              "================== WAN MANAGER • ENV ==================",
-             f"Python : {python_ver}", f"PyTorch : {torch_ver}",
-             f"CUDA (torch) : {cuda_rt}", f"cuDNN : {cudnn_ver}",
-             f"CUDA available : {cuda_ok}", f"GPU count : {n_gpu}",
              f"GPUs : {', '.join(devs) if devs else 'n/a'}",
              f"GPU VRAM : {', '.join(total_vram) if total_vram else 'n/a'}",
              f"Compute Capability : {', '.join(caps) if caps else 'n/a'}",
-             f"BF16 supported : {bf16_supported}", f"TF32 allowed : {tf32_allowed}",
-             f"SDPA API : {sdpa_api}", f"xFormers available : {xformers_ok}",
-             f"CUDA_VISIBLE_DEVICES: {visible}", f"PYTORCH_CUDA_ALLOC_CONF: {alloc_conf}",
              f"nvcc : {nvcc_ver}",
              "=======================================================",
          ]
          print("\n".join(banner_lines))

      def _round_multiple(self, x: int, multiple: int) -> int:
          return int(round(x / multiple) * multiple)

      def process_image_for_video(self, image: Image.Image) -> Image.Image:
          w, h = image.size
-         if w == h: return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)
          ar = w / h
          nw, nh = w, h
          if nw > self.MAX_DIMENSION or nh > self.MAX_DIMENSION:
              s = (self.MAX_DIMENSION / nw) if ar > 1 else (self.MAX_DIMENSION / nh)
              nw, nh = nw * s, nh * s
          if nw < self.MIN_DIMENSION or nh < self.MIN_DIMENSION:
              s = (self.MIN_DIMENSION / nh) if ar > 1 else (self.MIN_DIMENSION / nw)
              nw, nh = nw * s, nh * s
          fw = self._round_multiple(int(nw), self.DIMENSION_MULTIPLE)
          fh = self._round_multiple(int(nh), self.DIMENSION_MULTIPLE)
          fw = max(fw, self.MIN_DIMENSION if ar < 1 else self.SQUARE_SIZE)
          fh = max(fh, self.MIN_DIMENSION if ar > 1 else self.SQUARE_SIZE)
          return image.resize((fw, fh), Image.Resampling.LANCZOS)

      def resize_and_crop_to_match(self, target: Image.Image, ref: Image.Image) -> Image.Image:
@@ -191,9 +243,10 @@ class WanManager:
          left, top = (nw - rw) // 2, (nh - rh) // 2
          return resized.crop((left, top, left + rw, top + rh))

      def generate_video_from_conditions(
          self,
-         images_condition_items: List[List[Any]],
          prompt: str,
          negative_prompt: Optional[str],
          duration_seconds: float,
@@ -203,7 +256,8 @@ class WanManager:
          seed: int,
          randomize_seed: bool,
          output_type: str = "np",
-     ) -> Tuple[str, int, Optional[str], Optional[str]]:
          if not images_condition_items or len(images_condition_items) < 2:
              raise ValueError("Forneça ao menos dois itens (início e fim).")
212
  end_image = items[-1][0]
213
  if start_image is None or end_image is None:
214
  raise ValueError("As imagens inicial e final não podem ser vazias.")
 
 
215
 
 
216
  handle_image = items[1][0] if len(items) >= 3 else None
 
 
217
  handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
218
  end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0
219
 
 
220
  processed_start = self.process_image_for_video(start_image)
221
  processed_end = self.resize_and_crop_to_match(end_image, processed_start)
222
  processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image else None
223
 
224
  H, W = processed_start.height, processed_start.width
225
-
226
- # 1. Calcula e valida o número total de frames
227
- initial_frames = int(round(duration_seconds * self.FIXED_FPS))
228
- clamped_frames = int(np.clip(initial_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
229
- sf_t = getattr(self.pipe, "vae_scale_factor_temporal", 4)
230
- num_frames = ((clamped_frames - 1) // sf_t * sf_t) + 1 # Garante o formato 4n+1
231
-
232
- print(f"[WanManager] INFO: Duração {duration_seconds}s => {initial_frames} frames. "
233
- f"Após clamp e alinhamento 4n+1, o total de frames final é {num_frames}.")
234
 
 
 
 
 
 
235
  current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
236
  generator = torch.Generator().manual_seed(current_seed)
237
 
238
- denoise_callback = DenoiseStepLogger(self.pipe)
239
- callback_kwargs = {"callback_on_step_end": denoise_callback, "callback_on_step_end_tensor_inputs": ["latents"]}
240
-
241
  call_kwargs = dict(
242
- image=processed_start, last_image=processed_end, prompt=prompt, negative_prompt=negative_prompt or self.default_negative_prompt,
243
- height=H, width=W, num_frames=num_frames, guidance_scale=float(guidance_scale), guidance_scale_2=float(guidance_scale_2),
244
- num_inference_steps=int(steps), generator=generator, output_type=output_type,
 
 
 
 
 
 
 
 
 
245
  )
246
 
247
- # 2. Calcula e valida o frame de controle (handle)
248
- corrected_handle_index = None
249
- if processed_handle is not None:
250
- handle_frame_ui = int(items[1][1]) if len(items) >= 3 and items[1][1] is not None else 17
251
-
252
- block_index = round(handle_frame_ui / 8)
253
- aligned_frame = (block_index * 8 )+ 1
254
-
255
- min_safe_frame = 9 # Buffer de 8 frames no início (1*8 + 1)
256
- max_safe_frame = num_frames - 9 # Buffer de 8 frames no fim
257
-
258
- corrected_handle_index = max(min_safe_frame, min(aligned_frame, max_safe_frame))
259
-
260
- print(f"[WanManager] INFO: Handle Frame UI {handle_frame_ui} alinhado para {aligned_frame} e validado para {corrected_handle_index} (limites seguros: {min_safe_frame}-{max_safe_frame}).")
261
-
262
- base_kwargs = {**call_kwargs, "anchor_weight_last": float(end_weight)}
263
  if processed_handle is not None:
264
- base_kwargs.update({
265
- "handle_image": processed_handle,
266
- "handle_weight": float(handle_weight),
267
- "handle_frame_index": corrected_handle_index,
268
- })
269
-
270
- final_kwargs = {**base_kwargs, **callback_kwargs}
 
 
 
 
 
 
 
271
  result = None
 
 
272
 
273
- result = self.pipe(**base_kwargs)
274
  frames = result.frames[0]
275
 
276
-
277
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
278
  video_path = tmp.name
279
  export_to_video(frames, video_path, fps=self.FIXED_FPS)
280
 
281
- debug_video_path, grid_image_path = None, None
282
- if denoise_callback.intermediate_frames:
283
- with tempfile.NamedTemporaryFile(suffix="_denoise_process.mp4", delete=False) as tmp:
284
- debug_video_path = tmp.name
285
- denoise_callback.save_as_video(debug_video_path, fps=max(1, steps // 2))
286
-
287
- grid_pil = denoise_callback.create_steps_grid()
288
- if grid_pil:
289
- with tempfile.NamedTemporaryFile(suffix="_steps_grid.png", delete=False) as tmp:
290
- grid_image_path = tmp.name
291
- grid_pil.save(grid_image_path)
292
-
293
- return video_path, current_seed, debug_video_path, grid_image_path
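For readers comparing the two sides of this diff: the removed block above clamped the requested duration into the model's frame range, snapped the total to a 4n+1 count, and snapped the handle frame to an 8n+1 position with 8-frame safety buffers. Below is a minimal standalone sketch of that arithmetic, using the constants defined in the class (FIXED_FPS=16, frame range 8–81, temporal scale factor 4); the helper name align_frames is illustrative and not part of the commit.

import numpy as np

def align_frames(duration_seconds: float, handle_frame_ui: int = 17):
    # total frames: clamp to [8, 81], then floor to the nearest 4n+1 value
    initial = int(round(duration_seconds * 16))      # FIXED_FPS = 16
    clamped = int(np.clip(initial, 8, 81))
    num_frames = ((clamped - 1) // 4) * 4 + 1        # 4n+1 alignment

    # handle frame: snap to 8n+1, keep an 8-frame buffer from both ends
    aligned = round(handle_frame_ui / 8) * 8 + 1
    handle_index = max(9, min(aligned, num_frames - 9))
    return num_frames, handle_index

# Examples: 3 s at 16 fps -> 48 frames, aligned down to 45; a UI handle of 17 stays at 17.
# align_frames(3.0, 17) == (45, 17); align_frames(5.0, 70) == (77, 68)

The new version drops this pre-alignment: it only clamps the frame count and, per its own comment, leaves the 4n+1 adjustment to the pipeline.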
 
  # aduc_framework/managers/wan_manager.py
+ # WanManager v0.1.4 (final)

  import os
  import platform

  import torch
  from PIL import Image

+ # SDPA / FlashAttention context (PyTorch 2.1+ / 2.0 fallback)
  try:
+     from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch 2.1+
      _SDPA_NEW = True
  except Exception:
+     from torch.backends.cuda import sdp_kernel as _legacy_sdp  # PyTorch 2.0
      _SDPA_NEW = False

  from diffusers import FlowMatchEulerDiscreteScheduler

  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
  from diffusers.utils.export_utils import export_to_video

  class WanManager:
      """
+     Wan i2v Manager:
+     - Startup banner with PyTorch/CUDA/SDPA/GPU checks
+     - Two 3D transformers (high/low noise), bf16, device_map='auto', per-GPU max_memory
+     - Lightning LoRA fused and then unloaded
+     - SDPA with FlashAttention preferred, plus fallback (efficient/math)
+     - Three anchors: image (t=0, weight 1), handle (UI frame k aligned to 1 mod 4), last (final t)
+     - Fallback if the pipeline does not support the custom args (handle/anchor)
      """

      MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
      TRANSFORMER_ID = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"

+     # Dimensions/frames
      MAX_DIMENSION = 832
      MIN_DIMENSION = 480
      DIMENSION_MULTIPLE = 16
      SQUARE_SIZE = 480
+
      FIXED_FPS = 16
      MIN_FRAMES_MODEL = 8
      MAX_FRAMES_MODEL = 81

      default_negative_prompt = (
+         "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,"
+         "JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,"
+         "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
      )

      def __init__(self) -> None:
+         # Startup check banner
          self._print_env_banner()
+
          print("Loading models into memory. This may take a few minutes...")

+         # Automatic sharding with valid keys (integers and "cpu")
          n_gpus = torch.cuda.device_count()
+         max_memory = {i: "45GiB" for i in range(n_gpus)}  # adjust to available VRAM
          max_memory["cpu"] = "120GiB"

          transformer = WanTransformer3DModel.from_pretrained(
+             self.TRANSFORMER_ID,
+             subfolder="transformer",
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             max_memory=max_memory,
          )
          transformer_2 = WanTransformer3DModel.from_pretrained(
+             self.TRANSFORMER_ID,
+             subfolder="transformer_2",
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             max_memory=max_memory,
          )
+
          self.pipe = WanImageToVideoPipeline.from_pretrained(
+             self.MODEL_ID,
+             transformer=transformer,
+             transformer_2=transformer_2,
+             torch_dtype=torch.bfloat16,
+         )
+
+         # FlowMatch Euler scheduler (shift=32.0)
+         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config, shift=32.0
          )

+         # Lightning LoRA (fusion)
          print("Applying 8-step Lightning LoRA...")
          try:
+             self.pipe.load_lora_weights(
+                 "Kijai/WanVideo_comfy",
+                 weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+                 adapter_name="lightx2v",
+             )
+             self.pipe.load_lora_weights(
+                 "Kijai/WanVideo_comfy",
+                 weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+                 adapter_name="lightx2v_2",
+                 load_into_transformer_2=True,
+             )
              self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
+
              print("Fusing LoRA weights into the main model...")
              self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
              self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
              self.pipe.unload_lora_weights()
+             print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
          except Exception as e:
+             print(f"[WanManager] AVISO: Falha ao fundir LoRA Lightning (seguirá sem fusão): {e}")

          print("All models loaded. Service is ready.")

+     # ---------- Banner/Checks ----------
      def _print_env_banner(self) -> None:
          def _safe_get(fn, default="n/a"):
+             try:
+                 return fn()
+             except Exception:
+                 return default

          torch_ver = getattr(torch, "__version__", "unknown")
          cuda_rt = getattr(torch.version, "cuda", "unknown")

              devs.append(f"cuda:{i} {props.name}")
              total_vram.append(f"{props.total_memory/1024**3:.1f}GiB")
              caps.append(f"{props.major}.{props.minor}")
+
+         # BF16/TF32
          try:
+             bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
+         except Exception:
+             bf16_supported = False
+         if cuda_ok and caps:
+             major = int(caps[0].split(".")[0])
+             bf16_supported = major >= 8
+         tf32_allowed = getattr(torch.backends.cuda.matmul, "allow_tf32", False)
+
+         # SDPA API
+         try:
+             from torch.nn.attention import sdpa_kernel as _probe1  # noqa
+             sdpa_api = "torch.nn.attention (2.1+)"
+         except Exception:
+             try:
+                 from torch.backends.cuda import sdp_kernel as _probe2  # noqa
+                 sdpa_api = "torch.backends.cuda (2.0)"
+             except Exception:
+                 sdpa_api = "unavailable"
+
+         # xFormers
+         try:
+             import xformers  # noqa
              xformers_ok = True
+         except Exception:
              xformers_ok = False

          alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "unset")

          nvcc = shutil.which("nvcc")
          nvcc_ver = "n/a"
          if nvcc:
+             try:
+                 nvcc_ver = subprocess.check_output([nvcc, "--version"], text=True).strip().splitlines()[-1]
+             except Exception:
+                 nvcc_ver = "n/a"

          banner_lines = [
              "================== WAN MANAGER • ENV ==================",
+             f"Python : {python_ver}",
+             f"PyTorch : {torch_ver}",
+             f"CUDA (torch) : {cuda_rt}",
+             f"cuDNN : {cudnn_ver}",
+             f"CUDA available : {cuda_ok}",
+             f"GPU count : {n_gpu}",
              f"GPUs : {', '.join(devs) if devs else 'n/a'}",
              f"GPU VRAM : {', '.join(total_vram) if total_vram else 'n/a'}",
              f"Compute Capability : {', '.join(caps) if caps else 'n/a'}",
+             f"BF16 supported : {bf16_supported}",
+             f"TF32 allowed : {tf32_allowed}",
+             f"SDPA API : {sdpa_api}",
+             f"xFormers available : {xformers_ok}",
+             f"CUDA_VISIBLE_DEVICES: {visible}",
+             f"PYTORCH_CUDA_ALLOC_CONF: {alloc_conf}",
              f"nvcc : {nvcc_ver}",
              "=======================================================",
          ]
          print("\n".join(banner_lines))

+     # ---------- image utils ----------
      def _round_multiple(self, x: int, multiple: int) -> int:
          return int(round(x / multiple) * multiple)

      def process_image_for_video(self, image: Image.Image) -> Image.Image:
          w, h = image.size
+         if w == h:
+             return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)
+
          ar = w / h
          nw, nh = w, h
+
+         # upper clamp
          if nw > self.MAX_DIMENSION or nh > self.MAX_DIMENSION:
              s = (self.MAX_DIMENSION / nw) if ar > 1 else (self.MAX_DIMENSION / nh)
              nw, nh = nw * s, nh * s
+
+         # lower clamp
          if nw < self.MIN_DIMENSION or nh < self.MIN_DIMENSION:
              s = (self.MIN_DIMENSION / nh) if ar > 1 else (self.MIN_DIMENSION / nw)
              nw, nh = nw * s, nh * s
+
          fw = self._round_multiple(int(nw), self.DIMENSION_MULTIPLE)
          fh = self._round_multiple(int(nh), self.DIMENSION_MULTIPLE)
+
+         # consistent final minimums
          fw = max(fw, self.MIN_DIMENSION if ar < 1 else self.SQUARE_SIZE)
          fh = max(fh, self.MIN_DIMENSION if ar > 1 else self.SQUARE_SIZE)
+
          return image.resize((fw, fh), Image.Resampling.LANCZOS)

      def resize_and_crop_to_match(self, target: Image.Image, ref: Image.Image) -> Image.Image:

          left, top = (nw - rw) // 2, (nh - rh) // 2
          return resized.crop((left, top, left + rw, top + rh))

+     # ---------- API ----------
      def generate_video_from_conditions(
          self,
+         images_condition_items: List[List[Any]],  # [[image(Image), frame(int|str), weight(float)], ...]
          prompt: str,
          negative_prompt: Optional[str],
          duration_seconds: float,

          seed: int,
          randomize_seed: bool,
          output_type: str = "np",
+     ) -> Tuple[str, int]:
+         # validation
          if not images_condition_items or len(images_condition_items) < 2:
              raise ValueError("Forneça ao menos dois itens (início e fim).")

          end_image = items[-1][0]
          if start_image is None or end_image is None:
              raise ValueError("As imagens inicial e final não podem ser vazias.")
+         if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
+             raise TypeError("Patches devem ser PIL.Image.")

+         # optional handle
          handle_image = items[1][0] if len(items) >= 3 else None
+
+         # weights
          handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
          end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0

+         # preprocess and align H x W
          processed_start = self.process_image_for_video(start_image)
          processed_end = self.resize_and_crop_to_match(end_image, processed_start)
          processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image else None

          H, W = processed_start.height, processed_start.width

+         # frames (the pipeline aligns to 4n+1 internally; only clamp here)
+         num_frames = int(round(duration_seconds * self.FIXED_FPS))
+         num_frames = int(np.clip(num_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
+
+         # seed
          current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
          generator = torch.Generator().manual_seed(current_seed)

+         # base arguments
          call_kwargs = dict(
+             image=processed_start,
+             last_image=processed_end,
+             prompt=prompt,
+             negative_prompt=negative_prompt if negative_prompt else self.default_negative_prompt,
+             height=H,
+             width=W,
+             num_frames=num_frames,
+             guidance_scale=float(guidance_scale),
+             guidance_scale_2=float(guidance_scale_2),
+             num_inference_steps=int(steps),
+             generator=generator,
+             output_type=output_type,
          )

+         # map the handle's UI frame to a latent index aligned to 1 (mod 4)
+         corrected_handle_index = int(items[1][1])
+
+         # assemble final kwargs (with/without handle)
          if processed_handle is not None:
+             kwargs = dict(
+                 **call_kwargs,
+                 handle_image=processed_handle,
+                 handle_weight=float(handle_weight),
+                 handle_latent_index=corrected_handle_index,
+                 anchor_weight_last=float(end_weight),
+             )
+         else:
+             kwargs = dict(
+                 **call_kwargs,
+                 anchor_weight_last=float(end_weight),
+             )
+
+         # execution with SDPA and backend fallback
          result = None
+
+         result = self.pipe(**kwargs)

          frames = result.frames[0]

          with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
              video_path = tmp.name
              export_to_video(frames, video_path, fps=self.FIXED_FPS)

+         return video_path, current_seed
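For context, the updated method now returns only (video_path, seed) and takes images_condition_items in the [[image, frame, weight], ...] shape noted in the signature comment. The sketch below is a hypothetical call, not part of the commit: the file names, prompt, and parameter values are illustrative, and whether the custom handle/anchor keyword arguments are honored depends on the patched pipeline, as the docstring's fallback note says.

from PIL import Image

manager = WanManager()  # prints the env banner, then loads and fuses the models

video_path, used_seed = manager.generate_video_from_conditions(
    images_condition_items=[
        [Image.open("start.png"), 0, 1.0],    # first anchor (t = 0)
        [Image.open("handle.png"), 33, 0.8],  # optional mid-sequence handle
        [Image.open("end.png"), None, 1.0],   # last anchor
    ],
    prompt="a slow cinematic pan across a mountain lake at dawn",
    negative_prompt=None,          # falls back to default_negative_prompt
    duration_seconds=4.0,          # clamped to 8..81 frames at 16 fps
    steps=8,                       # matches the fused 8-step Lightning LoRA
    guidance_scale=1.0,
    guidance_scale_2=1.0,
    seed=42,
    randomize_seed=False,
)
print(video_path, used_seed)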