x2XcarleX2x committed on
Commit
e7d2ed1
·
verified ·
1 Parent(s): 3d73884

Update aduc_framework/managers/vae_wan_manager.py

Browse files
aduc_framework/managers/vae_wan_manager.py CHANGED
@@ -85,13 +85,20 @@ class VaeWanManager:
85
  raise e
86
 
87
def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
    """Convert a PIL image to the 5D tensor format (B, C, F, H, W) expected by the video VAE.

    The image is fitted (cropped/resized) to ``target_resolution``, scaled to
    the [-1, 1] range, and given batch and frame dimensions.

    Args:
        pil_image: Input image; converted to RGB internally.
        target_resolution: Target size passed to ``ImageOps.fit`` —
            presumably (width, height); confirm against callers.

    Returns:
        A float32 tensor of shape (1, 3, 1, H, W) with values in [-1, 1].
    """
    from PIL import ImageOps

    img = pil_image.convert("RGB")
    processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
    image_np = np.array(processed_img).astype(np.float32) / 255.0

    # (H, W, C) -> (B, C, H, W)
    tensor_4d = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
    tensor_4d_normalized = (tensor_4d * 2.0) - 1.0

    # BUG FIX: the original returned the 4D tensor directly, but the video
    # VAE consumes 5D input. Insert a singleton "frame" axis to produce
    # (B, C, F, H, W) with F == 1.
    return tensor_4d_normalized.unsqueeze(2)
 
 
 
 
 
 
 
95
 
96
  @torch.no_grad()
97
  def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
@@ -102,9 +109,9 @@ class VaeWanManager:
102
 
103
  latents_list = []
104
  for img in pil_images:
 
105
  pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype)
106
 
107
- # Usa a função oficial do diffusers para extrair os latentes
108
  encoder_output = self.vae.encode(pixel_tensor_gpu)
109
  latents = retrieve_latents(encoder_output)
110
 
@@ -121,7 +128,7 @@ class VaeWanManager:
121
 
122
  latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
123
 
124
- # Acessa a saída através do atributo .sample para compatibilidade
125
  decode_output = self.vae.decode(latent_tensor_gpu)
126
  pixels = decode_output.sample
127
 
 
85
  raise e
86
 
87
def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
    """Convert a PIL image into the 5D (B, C, F, H, W) tensor the video VAE consumes.

    Args:
        pil_image: Input image; converted to RGB internally.
        target_resolution: Target size handed to ``ImageOps.fit``.

    Returns:
        A float32 tensor of shape (1, 3, 1, H, W) with values in [-1, 1].
    """
    from PIL import ImageOps

    rgb = pil_image.convert("RGB")
    fitted = ImageOps.fit(rgb, target_resolution, Image.Resampling.LANCZOS)

    # Scale pixel values into [0, 1] as float32.
    pixels = np.array(fitted).astype(np.float32) / 255.0

    # (H, W, C) -> (1, C, H, W), then map [0, 1] -> [-1, 1].
    batched = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    normalized = (batched * 2.0) - 1.0

    # Insert a singleton frame axis: (1, C, 1, H, W).
    return normalized.unsqueeze(2)
102
 
103
  @torch.no_grad()
104
  def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
 
109
 
110
  latents_list = []
111
  for img in pil_images:
112
+ # A função de pré-processamento agora retorna o tensor 5D correto
113
  pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype)
114
 
 
115
  encoder_output = self.vae.encode(pixel_tensor_gpu)
116
  latents = retrieve_latents(encoder_output)
117
 
 
128
 
129
  latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
130
 
131
+ # Acessa a saída através do atributo .sample
132
  decode_output = self.vae.decode(latent_tensor_gpu)
133
  pixels = decode_output.sample
134