Delete pipeline_bddm.py

pipeline_bddm.py    +0 -304    DELETED

@@ -1,304 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# DiffWave: A Versatile Diffusion Model for Audio Synthesis
# (https://arxiv.org/abs/2009.09761)
# Modified from https://github.com/philsyn/DiffWave-Vocoder
#
# Author: Max W. Y. Lam ([email protected])
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################


import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from diffusers.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline


def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher-dimensional space.
    E.g. the embedding vector in the 128-dimensional space is
    [sin(t * 10^(-0*4/63)), ... , sin(t * 10^(-63*4/63)),
     cos(t * 10^(-0*4/63)), ... , cos(t * 10^(-63*4/63))]

    Parameters:
        diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
            diffusion steps for batch data
        diffusion_step_embed_dim_in (int, default=128):
            dimensionality of the embedding space for discrete diffusion steps

    Returns:
        the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))
    """
    assert diffusion_step_embed_dim_in % 2 == 0

    half_dim = diffusion_step_embed_dim_in // 2
    _embed = np.log(10000) / (half_dim - 1)
    # Geometric frequency ladder, placed on the same device as the input steps
    # (the original hard-coded .cuda(), which breaks CPU inference)
    _embed = torch.exp(torch.arange(half_dim) * -_embed).to(diffusion_steps.device)
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
    return diffusion_step_embed
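
# Worked example (aside, not part of the original file): for t = 5 and a
# 128-dimensional embedding, half_dim = 64 and the frequencies are
# 10000 ** (-i / 63) for i = 0..63, i.e. 10^(-4i/63) as in the docstring, so the
# first entries are sin(5 * 1), sin(5 * 10000 ** (-1 / 63)), ..., then the cosines:
#
#   emb = calc_diffusion_step_embedding(torch.tensor([[5]]), 128)
#   assert emb.shape == (1, 128)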


"""
Below scripts were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""


def swish(x):
    return x * torch.sigmoid(x)
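
# Aside (not part of the original file): swish is the SiLU activation, so on
# PyTorch >= 1.7 the built-in computes the same values:
#
#   x = torch.randn(4)
#   assert torch.allclose(swish(x), F.silu(x))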


# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out
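
# Padding check (aside, not part of the original file): for odd kernel sizes,
# padding = dilation * (kernel_size - 1) // 2 keeps the sequence length
# unchanged at stride 1, e.g. kernel_size=3, dilation=4 gives padding=4 and
# L_out = L_in + 2*4 - 4*(3 - 1) = L_in:
#
#   y = Conv(1, 1, kernel_size=3, dilation=4)(torch.randn(1, 1, 100))
#   assert y.shape[-1] == 100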


# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        out = self.conv(x)
        return out
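
# Aside (not part of the original file): zeroing both weight and bias makes the
# final projection output exactly zero at initialization, a FloWaveNet/Glow-style
# trick so the untrained network starts from a neutral prediction:
#
#   assert ZeroConv1d(8, 1)(torch.randn(2, 8, 16)).abs().sum() == 0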


# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels

        # Use a FC layer for diffusion step embedding
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
                                       kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
                                              padding=(1, s // 2),
                                              stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)

        # 80 is the number of mel bands
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)

        # Residual conv1x1 layer, connect to next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # Skip conv1x1 layer, add to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, diffusion_step_embed = input_data
        h = x
        batch_size, n_channels, seq_len = x.shape
        assert n_channels == self.res_channels

        # Add in diffusion step embedding. Note: h aliases x, so the in-place +=
        # shifts x as well and the residual connection below carries the
        # embedding; the caller clones the input to make this mutation safe.
        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([batch_size, self.res_channels, 1])
        h += part_t

        # Dilated conv layer
        h = self.dilated_conv_layer(h)

        # Upsample spectrogram to the size of the audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= seq_len
        if mel_spec.size(2) > seq_len:
            mel_spec = mel_spec[:, :, :seq_len]

        mel_spec = self.mel_conv(mel_spec)
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # Residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        # Normalize for training stability
        return (x + res) * math.sqrt(0.5), skip
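
# Upsampling arithmetic (aside, not part of the original file): each
# ConvTranspose2d with stride (1, s), kernel (3, 2s), padding (1, s // 2) scales
# the time axis by exactly s for even s, since
# L_out = (L_in - 1) * s - 2 * (s // 2) + 2 * s = s * L_in.
# Two layers with s = 16 therefore upsample the mel spectrogram 256x, matching
# the 256-sample hop assumed by the pipeline below:
#
#   up = nn.ConvTranspose2d(1, 1, (3, 32), padding=(1, 8), stride=(1, 16))
#   assert up(torch.randn(1, 1, 80, 10)).shape[-1] == 160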


class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # Shared pair of FC layers for the diffusion step embedding
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(res_channels, skip_channels,
                              dilation=2 ** (n % dilation_cycle),
                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed the diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        # Pass through all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            # Use the output from the last residual layer
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
            # Accumulate all skip outputs
            skip += skip_n

        # Normalize for training stability
        return skip * math.sqrt(1.0 / self.num_res_layers)
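
# Dilation schedule (aside, not part of the original file): with the defaults
# num_res_layers=30 and dilation_cycle=10, the per-block dilations
# 2 ** (n % 10) run through 1, 2, 4, ..., 512 three times:
#
#   assert [2 ** (n % 10) for n in range(30)] == [2 ** d for d in range(10)] * 3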


class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # register all init arguments with self.register
        self.register(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with ReLU
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1),
                                       nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(res_channels,
                                            skip_channels,
                                            num_res_layers,
                                            dilation_cycle,
                                            diffusion_step_embed_dim_in,
                                            diffusion_step_embed_dim_mid,
                                            diffusion_step_embed_dim_out)
        # Final conv1x1 -> ReLU -> zero conv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(inplace=False),
                                        ZeroConv1d(skip_channels, out_channels))

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
        x = audio
        # Clone so that the in-place embedding add inside ResidualBlock
        # does not mutate the output of init_conv
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)
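
# Shape smoke test (aside, not part of the original file; runs on CPU or GPU
# given the device fix in calc_diffusion_step_embedding above):
#
#   model = DiffWave()
#   audio = torch.randn(1, 1, 2560)   # (batch, channels, samples)
#   mel = torch.randn(1, 80, 10)      # 10 mel frames * 256 hop = 2560 samples
#   t = torch.ones(1, 1)              # one diffusion step per batch item
#   out = model((audio, mel, t))      # -> torch.Size([1, 1, 2560])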


class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        # 256 audio samples per mel frame (the upsampler's 16 * 16 stride)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample Gaussian noise to begin the loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict the noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict the previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
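
Each loop iteration realizes the ancestral sampling step x_{t-1} = mu_theta(x_t, t) + sigma_t * z with z ~ N(0, I) for t > 0 and z = 0 at t = 0. For reference, a minimal sketch of how this pipeline was meant to be driven before its removal; the checkpoint id here is hypothetical, and a real mel spectrogram would come from a TTS frontend rather than torch.randn:

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id, for illustration only
pipeline = DiffusionPipeline.from_pretrained("org/bddm-diffwave-vocoder")
mel = torch.randn(1, 80, 200)                # stand-in for a real mel spectrogram
generator = torch.Generator().manual_seed(0)
audio = pipeline(mel, generator)             # -> (1, 1, 200 * 256) waveform tensor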