import math

import torch
import torch.nn.functional as F
from einops import repeat
from torch import nn

from .transformer import MultiviewTransformer


def timestep_embedding(
    timesteps: torch.Tensor,
    dim: int,
    max_period: int = 10000,
    repeat_only: bool = False,
) -> torch.Tensor:
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding
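
# Illustrative usage (shapes follow from the code above; the width 320 is an
# arbitrary example value): for noise levels `t` of shape (B,),
# `timestep_embedding(t, 320)` returns a (B, 320) sinusoidal embedding with
# cosines in the first half and sines in the second; with `repeat_only=True`
# the raw timestep is simply broadcast to (B, 320) instead.
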
class Upsample(nn.Module):
    def __init__(self, channels: int, out_channels: int | None = None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, channels: int, out_channels: int | None = None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[1] == self.channels
        return self.op(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back to the
        # input dtype.
        return super().forward(input.float()).type(input.dtype)


class TimestepEmbedSequential(nn.Sequential):
    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        emb: torch.Tensor,
        context: torch.Tensor,
        dense_emb: torch.Tensor,
        num_frames: int,
    ) -> torch.Tensor:
        for layer in self:
            if isinstance(layer, MultiviewTransformer):
                assert num_frames is not None
                x = layer(x, context, num_frames)
            elif isinstance(layer, ResBlock):
                x = layer(x, emb, dense_emb)
            else:
                x = layer(x)
        return x
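
# Note: TimestepEmbedSequential is a plain nn.Sequential whose forward fans
# out the extra conditioning arguments. MultiviewTransformer layers receive
# (x, context, num_frames), ResBlock layers receive (x, emb, dense_emb), and
# every other layer receives only x.
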
class ResBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: int | None,
        dense_in_channels: int,
        dropout: float,
    ):
        super().__init__()
        out_channels = out_channels or channels

        self.in_layers = nn.Sequential(
            GroupNorm32(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, 1, 1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(), nn.Linear(emb_channels, out_channels)
        )
        self.dense_emb_layers = nn.Sequential(
            nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
        )
        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )
        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
    ) -> torch.Tensor:
        in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
        h = in_rest(x)
        # Resize the dense conditioning map to the feature resolution and apply
        # it as a per-pixel scale and shift.
        dense = self.dense_emb_layers(
            F.interpolate(
                dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
            )
        ).type(h.dtype)
        dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
        h = h * (1 + dense_scale) + dense_shift
        h = in_conv(h)
        # Add the timestep embedding as a per-channel bias, broadcast over the
        # spatial dimensions.
        emb_out = self.emb_layers(emb).type(h.dtype)
        # TODO(hangg): Optimize this?
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        h = h + emb_out
        h = self.out_layers(h)
        h = self.skip_connection(x) + h
        return h
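
# Minimal shape-check sketch (illustrative; the channel sizes below are
# arbitrary example values, not taken from any particular model config):
#
#   block = ResBlock(
#       channels=320, emb_channels=1280, out_channels=320,
#       dense_in_channels=6, dropout=0.0,
#   )
#   x = torch.randn(2, 320, 32, 32)     # (B, C, H, W) feature map
#   emb = torch.randn(2, 1280)          # per-sample timestep embedding
#   dense = torch.randn(2, 6, 64, 64)   # dense spatial conditioning map
#   out = block(x, emb, dense)          # -> (2, 320, 32, 32)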