from typing import List, Optional, Tuple, Union

import torch
from diffusers.utils.torch_utils import randn_tensor


def random_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int, ...] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    noise_offset: Optional[float] = None,  # typical value is 0.1
) -> torch.Tensor:
    # If a reference tensor is given, inherit its shape, device, and dtype.
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)
    noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator)
    if noise_offset is not None:
        # https://www.crosslabs.org//blog/diffusion-with-offset-noise
        # Per-(batch, channel) scalar offset, broadcast over the remaining
        # dimensions. Assumes 5D (B, C, T, H, W) latents.
        noise += noise_offset * torch.randn(
            (shape[0], shape[1], 1, 1, 1), device=device, dtype=dtype
        )
    return noise
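
# A minimal, hypothetical usage sketch (not part of the original file): draw
# offset noise matched to a 5D video-latent batch. The helper name, shapes,
# and seed below are illustrative assumptions.
def _random_noise_example() -> torch.Tensor:
    latents = torch.zeros(2, 4, 8, 32, 32)  # (B, C, T, H, W)
    gen = torch.Generator().manual_seed(0)
    # Shape/device/dtype are inherited from `latents`; the 0.1 offset shifts
    # the per-sample mean, which helps models learn very dark/bright images.
    return random_noise(tensor=latents, generator=gen, noise_offset=0.1)
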
def video_fusion_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int, ...] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    w_ind_noise: float = 0.5,
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    initial_common_noise: torch.Tensor = None,
) -> torch.Tensor:
    # If a reference tensor is given, inherit its shape, device, and dtype.
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)
    batch_size, c, t, h, w = shape
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    if not isinstance(generator, list):
        # Noise shared by every frame of a video, broadcast over the time axis.
        if initial_common_noise is not None:
            common_noise = initial_common_noise.to(device, dtype=dtype)
        else:
            common_noise = randn_tensor(
                (shape[0], shape[1], 1, shape[3], shape[4]),
                generator=generator,
                device=device,
                dtype=dtype,
            )  # common noise
        ind_noise = randn_tensor(
            shape,
            generator=generator,
            device=device,
            dtype=dtype,
        )  # individual (per-frame) noise
        # Variance-preserving mix: (1 - s) + s = 1, so the result stays unit
        # variance while w_ind_noise controls frame-to-frame decorrelation.
        s = torch.tensor(w_ind_noise, device=device, dtype=dtype)
        latents = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise
    else:
        # Per-sample generators: build each batch element independently.
        latents = []
        for i in range(batch_size):
            latent = video_fusion_noise(
                shape=(1, c, t, h, w),
                dtype=dtype,
                device=device,
                w_ind_noise=w_ind_noise,
                generator=generator[i],
                initial_common_noise=initial_common_noise,
            )
            latents.append(latent)
        latents = torch.cat(latents, dim=0).to(device)
    return latents
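
# A minimal, hypothetical usage sketch (not part of the original file): seeded
# fusion noise for a batch of two 8-frame video latents, exercising the
# per-sample-generator path. Shapes and seeds are illustrative assumptions.
if __name__ == "__main__":
    gens = [torch.Generator().manual_seed(i) for i in range(2)]
    noise = video_fusion_noise(
        shape=(2, 4, 8, 32, 32),  # (B, C, T, H, W)
        dtype=torch.float32,
        device=torch.device("cpu"),
        w_ind_noise=0.5,  # 0 -> identical frames, 1 -> fully independent frames
        generator=gens,
    )
    print(noise.shape)  # torch.Size([2, 4, 8, 32, 32])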