# aduc_framework/tools/pipeline_patches.py (central hub for ADUC modifications)

import torch
import logging
from typing import List, Optional, Union

# --- Types from our architecture ---
from ..types import LatentConditioningItem

# --- Pipeline classes we are going to modify ---
# Wrapped in try/except so the linter can analyze this file even when the
# optional dependencies are not installed.
try:
    from ..managers.pipeline_wan_i2v import WanImageToVideoPipeline
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem
    from ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords
    from diffusers.utils.torch_utils import randn_tensor
except ImportError:
    WanImageToVideoPipeline = None
    LTXVideoPipeline = None
    ConditioningItem = None
    latent_to_pixel_coords = None
    randn_tensor = None

logger = logging.getLogger(__name__)
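
# The patches below rely on only three attributes of LatentConditioningItem. As an
# illustrative sketch (the real definition lives in aduc_framework.types), a minimal
# stand-in compatible with them would look like:
#
#     @dataclass
#     class LatentConditioningItem:
#         latent_tensor: torch.Tensor    # pre-encoded VAE latents (..., frames, height, width)
#         media_frame_number: int        # target frame index (latent index for Wan, pixel-frame offset for LTX)
#         conditioning_strength: float   # lerp weight: 0.0 keeps the noise, 1.0 fully replaces it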

# ==============================================================================
# PATCH #1: WanImageToVideo pipeline (Wan2.2)
# Goal: teach the pipeline to use `LatentConditioningItem` for ADUC control.
# ==============================================================================
def prepare_latents_patch_for_wan_i2v(
    self: WanImageToVideoPipeline,
    conditioning_items: List[LatentConditioningItem],
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    num_frames: int,
    dtype: torch.dtype,
    device: torch.device,
    generator,
    latents: Optional[torch.Tensor] = None,
    **kwargs,  # accepts and ignores extra arguments such as 'image' and 'last_image'
) -> tuple[torch.Tensor, torch.Tensor]:
    """Monkey patch for the WanImageToVideo pipeline that enables LatentConditioningItem inputs."""
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)

    init_latents = latents if latents is not None else torch.randn(shape, generator=generator, device=device, dtype=dtype)
    init_latents = init_latents.to(device=device, dtype=dtype)
    # Conditioning mask: built at pixel-frame resolution, then folded into latent frames.
    mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width, device=device, dtype=dtype)
    mask_lat_size[:, :, 1:] = 0
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
    mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
    mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2).to(init_latents.device)

    logger.info(f"WAN_PATCH: applying {len(conditioning_items)} conditioning items.")
    for item in conditioning_items:
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        frame_idx, strength = item.media_frame_number, item.conditioning_strength
        if frame_idx >= num_latent_frames:
            logger.warning(f"WAN_PATCH: frame_idx {frame_idx} is out of bounds. Skipping.")
            continue
        # Blend the pre-encoded latents into the noise and record their strength in the mask.
        f_l, h_l, w_l = media_item_latents.shape[-3:]
        init_latents[:, :, frame_idx:frame_idx + f_l, :h_l, :w_l] = torch.lerp(
            init_latents[:, :, frame_idx:frame_idx + f_l, :h_l, :w_l], media_item_latents, strength
        )
        mask_lat_size[:, :, frame_idx, :h_l, :w_l] = strength

    condition = torch.concat([mask_lat_size, init_latents], dim=1)
    return init_latents, condition
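
# Hypothetical usage sketch (not wired up here, since the Wan patch is currently disabled
# in apply_aduc_patches): once assigned, prepare_latents can be driven with pre-encoded
# latents instead of images. `pipe`, `start_latents`, `end_latents` and the numeric values
# are illustrative assumptions, and the construction assumes LatentConditioningItem accepts
# the three fields shown above as keywords.
#
#     WanImageToVideoPipeline.prepare_latents = prepare_latents_patch_for_wan_i2v
#     items = [
#         LatentConditioningItem(latent_tensor=start_latents, media_frame_number=0, conditioning_strength=1.0),
#         LatentConditioningItem(latent_tensor=end_latents, media_frame_number=20, conditioning_strength=0.8),
#     ]
#     latents, condition = pipe.prepare_latents(
#         items, batch_size=1, num_channels_latents=16, height=480, width=832,
#         num_frames=81, dtype=torch.bfloat16, device=torch.device("cuda"), generator=None,
#     )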

# ==============================================================================
# PATCH #2: LTXVideo pipeline (LTX)
# Goal: teach the pipeline to use `LatentConditioningItem` for ADUC control.
# ==============================================================================
def prepare_conditioning_patch_for_ltx(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Monkey patch for the LTX-Video pipeline, focused on LatentConditioningItem inputs."""
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return init_latents, init_pixel_coords, None, 0

    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    logger.info(f"LTX_PATCH: applying {len(conditioning_items)} conditioning items.")
    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("LTX_PATCH: conditioning item is not a LatentConditioningItem and will be ignored.")
            continue
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength
        if media_frame_number == 0:
            # Frame 0: blend directly into the initial latents and mark the mask in place.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Later frames: noise-blend, patchify, and append as extra conditioning tokens.
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
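
# Hypothetical usage sketch: with the patch applied (see apply_aduc_patches below),
# LTXVideoPipeline.prepare_conditioning consumes pre-encoded latents directly. Names such
# as `pipe`, `first_clip_latents` and `anchor_latents` are illustrative assumptions, and
# the construction assumes LatentConditioningItem accepts these fields as keywords.
#
#     items = [
#         LatentConditioningItem(latent_tensor=first_clip_latents, media_frame_number=0, conditioning_strength=1.0),
#         LatentConditioningItem(latent_tensor=anchor_latents, media_frame_number=8, conditioning_strength=0.5),
#     ]
#     latents, pixel_coords, mask, num_extra = pipe.prepare_conditioning(
#         conditioning_items=items, init_latents=init_latents,
#         num_frames=num_frames, height=height, width=width, generator=generator,
#     )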

# ==============================================================================
# CENTRAL APPLICATION FUNCTION
# ==============================================================================
def apply_aduc_patches():
    """Central function that applies all of our ADUC-SDR patches."""
    logger.info("--- ADUC-SDR patch hub: applying modifications ---")

    # Patch the Wan2.2 pipeline (currently disabled)
    # if WanImageToVideoPipeline:
    #     logger.info("-> Patching 'WanImageToVideoPipeline.prepare_latents'...")
    #     WanImageToVideoPipeline.prepare_latents = prepare_latents_patch_for_wan_i2v
    # else:
    #     logger.warning("-> WanImageToVideoPipeline not found. Patch skipped.")

    # Patch the LTX pipeline
    if LTXVideoPipeline:
        logger.info("-> Patching 'LTXVideoPipeline.prepare_conditioning'...")
        LTXVideoPipeline.prepare_conditioning = prepare_conditioning_patch_for_ltx
    else:
        logger.warning("-> LTXVideoPipeline not found. Patch skipped.")

    logger.info("--- Pipeline modifications finished ---")