# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from training.models import *

#----------------------------------------------------------------------------

class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
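
# Usage sketch (illustrative only; the z_dim/w_dim/num_ws values below are
# assumptions, not taken from any config in this repo). Maps Z -> W and
# broadcasts to num_ws per-layer latents, optionally truncating toward the
# tracked W average:
#
#   mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14)
#   z = torch.randn(4, 512)
#   ws = mapping(z, c=None, truncation_psi=0.7)   # -> [4, 14, 512]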

#----------------------------------------------------------------------------

class EncoderNetwork(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        z_dim,                          # Input latent (Z) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'orig',   # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 16384,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for EncoderBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for EncoderEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        x = None
        feats = {}
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img, feat = block(x, img, **block_kwargs)
            feats[res] = feat
        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x, const_e = self.b4(x, cmap)
        feats[4] = const_e

        B, _ = x.shape
        z = torch.randn((B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device)  # Noise for co-modulation.
        return x, z, feats  # Skip features at 1/2, 1/4, 1/8, 1/16, 1/32, 1/64 of the input resolution.
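
# Usage sketch (illustrative only; shapes are assumptions, and EncoderBlock /
# EncoderEpilogue must be provided by training.models). The encoder yields a
# global image code, a fresh noise vector for co-modulation, and skip
# features keyed by resolution:
#
#   E = EncoderNetwork(c_dim=0, z_dim=512, img_resolution=256, img_channels=4)
#   x_global, z, feats = E(torch.randn(2, 4, 256, 256), c=None)
#   # z.shape == (2, 512); sorted(feats) == [4, 8, 16, 32, 64, 128, 256]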

#----------------------------------------------------------------------------

class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        z_dim,                      # Output latent (Z) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 16384,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0  # Must be a power of two.
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max), z_dim=z_dim * 2, resolution=4)

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            if res // 2 in channels_dict.keys():
                in_channels = channels_dict[res // 2] if res > 4 else 0
            else:
                in_channels = min(channel_base // (res // 2), channel_max)
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            setattr(self, f'b{res}', block)

    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        img = None
        x, img = self.foreword(x_global, ws, feats, img)
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            # Build the block's three modulation vectors by concatenating the
            # per-layer W latent with the global encoder code (co-modulation).
            k = int(np.log2(res)) * 2
            mod_vector0 = torch.cat([ws[:, k - 5], x_global.clone()], dim=1)
            mod_vector1 = torch.cat([ws[:, k - 4], x_global.clone()], dim=1)
            mod_vector_rgb = torch.cat([ws[:, k - 3], x_global.clone()], dim=1)
            x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
        return img
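
# W-index bookkeeping (an illustrative check, not part of the model): with
# num_ws = 2*log2(R) - 2, the block at resolution res reads W entries
# 2*log2(res)-5 .. 2*log2(res)-3, so the blocks jointly consume ws[:, 1:],
# and ws[:, 0] is presumably left to the foreword, which receives the full
# ws tensor (an assumption):
#
#   for res in [8, 16, 32, 64, 128, 256]:              # R = 256, num_ws = 14
#       k = int(np.log2(res))
#       print(res, [2 * k - 5, 2 * k - 4, 2 * k - 3])  # covers indices 1..13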

#----------------------------------------------------------------------------

class Generator(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        encoder_kwargs      = {},   # Arguments for EncoderNetwork.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        synthesis_kwargs    = {},   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, img_channels=img_channels, **encoder_kwargs)
        self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        mask = img[:, -1].unsqueeze(1)  # The last input channel carries the mask.
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img
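
# End-to-end sketch (illustrative only). forward() takes the mask from the
# last input channel (mask = img[:, -1]); stacking pre-masked RGB with the
# mask, and mask == 1 marking known pixels, are assumptions of this example:
#
#   G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=256, img_channels=4)
#   rgb = torch.randn(2, 3, 256, 256)
#   mask = (torch.rand(2, 1, 256, 256) > 0.5).float()
#   out = G(torch.cat([rgb * mask, mask], dim=1), c=None)   # completed image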

#----------------------------------------------------------------------------

class Discriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 16384,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)
        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x
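
# Usage sketch (illustrative only; shapes are assumptions). As in StyleGAN2,
# the epilogue typically reduces each image to a single realness logit:
#
#   D = Discriminator(c_dim=0, img_resolution=256, img_channels=4)
#   logits = D(torch.randn(2, 4, 256, 256), c=None)   # -> [2, 1]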

#----------------------------------------------------------------------------