import numpy as np
import torch
import torch.distributions as dist
from torch import nn

from modules.commons.conv import ConditionalConvBlocks
from modules.commons.normalizing_flow.res_flow import ResFlow
from modules.commons.wavenet import WN


class FVAEEncoder(nn.Module):
    def __init__(self, c_in, hidden_size, c_latent, kernel_size,
                 n_layers, c_cond=0, p_dropout=0, strides=[4], nn_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        # Downsample the input by prod(strides) with strided convolutions
        # (or keep the frame rate with a pointwise conv when prod(strides) == 1).
        if np.prod(strides) == 1:
            self.pre_net = nn.Conv1d(c_in, hidden_size, kernel_size=1)
        else:
            self.pre_net = nn.Sequential(*[
                nn.Conv1d(c_in, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
                if i == 0 else
                nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
                for i, s in enumerate(strides)
            ])
        if nn_type == 'wn':
            self.nn = WN(hidden_size, kernel_size, 1, n_layers, c_cond, p_dropout)
        elif nn_type == 'conv':
            self.nn = ConditionalConvBlocks(
                hidden_size, c_cond, hidden_size, None, kernel_size,
                layers_in_block=2, is_BTC=False, num_layers=n_layers)
        # Project to the mean and log-std of the Gaussian posterior q(z|x).
        self.out_proj = nn.Conv1d(hidden_size, c_latent * 2, 1)
        self.latent_channels = c_latent

    def forward(self, x, nonpadding, cond):
        """
        :param x: [B, C_in, T]
        :param nonpadding: [B, 1, T]
        :param cond: [B, C_cond, T // prod(strides)]
        :return: (z, m, logs, nonpadding), all at the downsampled rate
        """
        x = self.pre_net(x)
        # Downsample the mask to match the strided feature map.
        nonpadding = nonpadding[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
        x = x * nonpadding
        x = self.nn(x, nonpadding=nonpadding, cond=cond) * nonpadding
        x = self.out_proj(x)
        m, logs = torch.split(x, self.latent_channels, dim=1)
        # Reparameterization trick: z = m + eps * sigma, eps ~ N(0, I).
        z = m + torch.randn_like(m) * torch.exp(logs)
        return z, m, logs, nonpadding


class FVAEDecoder(nn.Module):
    def __init__(self, c_latent, hidden_size, out_channels, kernel_size,
                 n_layers, c_cond=0, p_dropout=0, strides=[4], nn_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        # Upsample the latent back to the frame rate with transposed convolutions.
        self.pre_net = nn.Sequential(*[
            nn.ConvTranspose1d(c_latent, hidden_size, kernel_size=s, stride=s)
            if i == 0 else
            nn.ConvTranspose1d(hidden_size, hidden_size, kernel_size=s, stride=s)
            for i, s in enumerate(strides)
        ])
        if nn_type == 'wn':
            self.nn = WN(hidden_size, kernel_size, 1, n_layers, c_cond, p_dropout)
        elif nn_type == 'conv':
            self.nn = ConditionalConvBlocks(
                hidden_size, c_cond, hidden_size, [1] * n_layers, kernel_size,
                layers_in_block=2, is_BTC=False)
        self.out_proj = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, x, nonpadding, cond):
        """
        :param x: [B, C_latent, T // prod(strides)]
        :param nonpadding: [B, 1, T]
        :param cond: [B, C_cond, T]
        :return: [B, C_out, T]
        """
        x = self.pre_net(x)
        x = x * nonpadding
        x = self.nn(x, nonpadding=nonpadding, cond=cond) * nonpadding
        x = self.out_proj(x)
        return x


class FVAE(nn.Module):
    def __init__(self,
                 c_in_out, hidden_size, c_latent,
                 kernel_size, enc_n_layers, dec_n_layers, c_cond, strides,
                 use_prior_flow, flow_hidden=None, flow_kernel_size=None, flow_n_steps=None,
                 encoder_type='wn', decoder_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        self.latent_size = c_latent
        self.use_prior_flow = use_prior_flow
        # Downsample the condition to the latent rate, mirroring the encoder's pre_net.
        if np.prod(strides) == 1:
            self.g_pre_net = nn.Conv1d(c_cond, c_cond, kernel_size=1)
        else:
            self.g_pre_net = nn.Sequential(*[
                nn.Conv1d(c_cond, c_cond, kernel_size=s * 2, stride=s, padding=s // 2)
                for s in strides
            ])
        self.encoder = FVAEEncoder(c_in_out, hidden_size, c_latent, kernel_size,
                                   enc_n_layers, c_cond, strides=strides, nn_type=encoder_type)
        if use_prior_flow:
            self.prior_flow = ResFlow(
                c_latent, flow_hidden, flow_kernel_size, flow_n_steps, 4, c_cond=c_cond)
        self.decoder = FVAEDecoder(c_latent, hidden_size, c_in_out, kernel_size,
                                   dec_n_layers, c_cond, strides=strides, nn_type=decoder_type)
        self.prior_dist = dist.Normal(0, 1)

    def forward(self, x=None, nonpadding=None, cond=None, infer=False, noise_scale=1.0):
        """
        :param x: [B, C_in_out, T]
        :param nonpadding: [B, 1, T]
        :param cond: [B, C_g, T]
        :return: (z_q, loss_kl, z_p, m_q, logs_q) during training, else z_p
        """
        if nonpadding is None:
            nonpadding = 1
        cond_sqz = self.g_pre_net(cond)  # condition at the downsampled latent rate
        if not infer:
            # Training path: sample z_q ~ q(z|x) and compute the KL regularizer.
            z_q, m_q, logs_q, nonpadding_sqz = self.encoder(x, nonpadding, cond_sqz)
            q_dist = dist.Normal(m_q, logs_q.exp())
            if self.use_prior_flow:
                # Monte-Carlo KL with a flow-enhanced prior: compare log q(z|x)
                # against the standard-normal density of the flow-mapped latent,
                # averaged over valid positions and latent channels.
                logqx = q_dist.log_prob(z_q)
                z_p = self.prior_flow(z_q, nonpadding_sqz, cond_sqz)
                logpx = self.prior_dist.log_prob(z_p)
                loss_kl = ((logqx - logpx) * nonpadding_sqz).sum() / nonpadding_sqz.sum() / logqx.shape[1]
            else:
                # Closed-form KL between the Gaussian posterior and N(0, 1),
                # averaged over valid positions and latent channels.
                loss_kl = dist.kl_divergence(q_dist, self.prior_dist)
                loss_kl = (loss_kl * nonpadding_sqz).sum() / nonpadding_sqz.sum() / z_q.shape[1]
                z_p = None
            return z_q, loss_kl, z_p, m_q, logs_q
        else:
            # Inference path: sample from the prior (optionally refined by the
            # inverse flow), conditioned only on cond.
            latent_shape = [cond_sqz.shape[0], self.latent_size, cond_sqz.shape[2]]
            z_p = torch.randn(latent_shape).to(cond.device) * noise_scale
            if self.use_prior_flow:
                z_p = self.prior_flow(z_p, 1, cond_sqz, reverse=True)
            return z_p
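

# A minimal usage sketch of this module. The hyperparameters below are
# illustrative assumptions, not values prescribed by this repo, and pairing
# FVAE.forward with FVAE.decoder (encode at the strided rate, decode back at
# the frame rate) is one plausible way callers can wire the pieces together.
if __name__ == '__main__':
    B, T = 2, 32                       # batch size and frame count; T is a
                                       # multiple of prod(strides) = 4
    c_in_out, c_cond = 80, 256         # e.g. mel bins and conditioning width
    fvae = FVAE(c_in_out=c_in_out, hidden_size=192, c_latent=16,
                kernel_size=5, enc_n_layers=8, dec_n_layers=4,
                c_cond=c_cond, strides=[4], use_prior_flow=False)
    x = torch.randn(B, c_in_out, T)    # target features, [B, C_in_out, T]
    cond = torch.randn(B, c_cond, T)   # frame-level condition, [B, C_g, T]
    nonpadding = torch.ones(B, 1, T)   # 1 for valid frames, 0 for padding
    # Training path: posterior sample plus the KL regularizer.
    z_q, loss_kl, z_p, m_q, logs_q = fvae(x, nonpadding, cond, infer=False)
    x_recon = fvae.decoder(z_q, nonpadding=nonpadding, cond=cond)
    # Inference path: sample the latent from the prior given cond only.
    z_prior = fvae(cond=cond, infer=True, noise_scale=0.8)
    x_gen = fvae.decoder(z_prior, nonpadding=nonpadding, cond=cond)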