Spaces:
Runtime error
Runtime error
| from einops import rearrange | |
| from torch import nn | |
| import torch | |
def decode_unet_latents_with_vae(vae: nn.Module, latents: torch.Tensor) -> torch.Tensor:
    """Decode UNet latents into pixel-space values clamped to ``[0, 1]``.

    Args:
        vae: Object exposing ``config.scaling_factor`` and
            ``decode(latents, return_dict=False)`` whose first returned
            element is the decoded tensor.
        latents: Either a 5-D video latent laid out as ``(b, c, f, h, w)``,
            or a tensor of any other rank that is handed to ``vae.decode``
            unchanged (presumably 4-D ``(b, c, h, w)`` -- confirm with
            callers).

    Returns:
        The decoded tensor after ``x / 2 + 0.5`` and clamping to ``[0, 1]``.
        For 5-D input the frame axis is restored, giving
        ``(b, c', f, h', w')``; otherwise the decoder's own layout is kept.
        No dtype or device conversion is performed here.
    """
    is_video = latents.ndim == 5
    if is_video:
        batch_size, _, num_frames = latents.shape[:3]
        # Fold frames into the batch axis: "b c f h w -> (b f) c h w".
        latents = latents.transpose(1, 2).flatten(0, 1)

    # Undo the scaling applied when the latents were produced.
    latents = 1 / vae.config.scaling_factor * latents
    video = vae.decode(latents, return_dict=False)[0]
    # Shift/scale by x / 2 + 0.5 (assumes decoder output is roughly in
    # [-1, 1]) and clamp into [0, 1].
    video = (video / 2 + 0.5).clamp(0, 1)

    if is_video:
        # Unfold the frames again: "(b f) c h w -> b c f h w".  The original
        # code rearranged `latents` here and discarded the result, so callers
        # received frames still merged into the batch axis.
        video = video.reshape(batch_size, num_frames, *video.shape[1:]).transpose(1, 2)
    return video