x2XcarleX2x's picture
Update aduc_framework/utils/callbacks.py
0374496 verified
# aduc_framework/utils/callbacks.py
import imageio
import torch
import numpy as np
import tempfile
from diffusers.utils.export_utils import export_to_video
class DenoiseStepLogger:
"""
Callback que, em cada passo do denoising, decodifica a sequência de vídeo
inteira e a salva como um clipe MP4 individual.
"""
def __init__(self, pipe, fps=8):
self.pipe = pipe
self.fps = fps
# Armazena os caminhos para os vídeos gerados em cada passo
self.step_video_paths = []
self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
def decode_latents_to_video_path(self, latents: torch.Tensor, step: int) -> str:
"""
Decodifica um tensor latente 5D, salva o vídeo resultante em um arquivo
temporário e retorna o caminho para esse arquivo.
"""
correct_device = latents.device
latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
# O resultado já é um lote de frames de vídeo
frames = self.pipe.video_processor.postprocess_video(video=video_tensor, output_type="np")
with tempfile.NamedTemporaryFile(suffix=f"_step_{step+1}.mp4", delete=False) as tmp:
video_path = tmp.name
export_to_video(frames[0], video_path, fps=self.fps)
return video_path
def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
"""
Chamado pela pipeline a cada passo.
"""
print(f" -> Callback: Decodificando vídeo completo do passo de denoising {step+1}...")
latents = callback_kwargs["latents"]
# Gera o vídeo para o passo atual e armazena seu caminho
video_path = self.decode_latents_to_video_path(latents, step)
self.step_video_paths.append(video_path)
yield None, None, self.denoising_step_videos
return callback_kwargs