import logging import math import torch import torch.nn as nn logger = logging.getLogger(__name__) @torch.jit.script def _fused_tanh_sigmoid(h): a, b = h.chunk(2, dim=1) h = a.tanh() * b.sigmoid() return h class WNLayer(nn.Module): """ A DiffWave-like WN """ def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation): super().__init__() local_output_dim = hidden_dim * 2 if global_dim is not None: self.gconv = nn.Conv1d(global_dim, hidden_dim, 1) if local_dim is not None: self.lconv = nn.Conv1d(local_dim, local_output_dim, 1) self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same") self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1) def forward(self, z, l, g): identity = z if g is not None: if g.dim() == 2: g = g.unsqueeze(-1) z = z + self.gconv(g) z = self.dconv(z) if l is not None: z = z + self.lconv(l) z = _fused_tanh_sigmoid(z) h = self.out(z) z, s = h.chunk(2, dim=1) o = (z + identity) / math.sqrt(2) return o, s class WN(nn.Module): def __init__( self, input_dim, output_dim, local_dim=None, global_dim=None, n_layers=30, kernel_size=3, dilation_cycle=5, hidden_dim=512, ): super().__init__() assert kernel_size % 2 == 1 assert hidden_dim % 2 == 0 self.input_dim = input_dim self.hidden_dim = hidden_dim self.local_dim = local_dim self.global_dim = global_dim self.start = nn.Conv1d(input_dim, hidden_dim, 1) if local_dim is not None: self.local_norm = nn.InstanceNorm1d(local_dim) self.layers = nn.ModuleList( [ WNLayer( hidden_dim=hidden_dim, local_dim=local_dim, global_dim=global_dim, kernel_size=kernel_size, dilation=2 ** (i % dilation_cycle), ) for i in range(n_layers) ] ) self.end = nn.Conv1d(hidden_dim, output_dim, 1) def forward(self, z, l=None, g=None): """ Args: z: input (b c t) l: local condition (b c t) g: global condition (b d) """ z = self.start(z) if l is not None: l = self.local_norm(l) # Skips s_list = [] for layer in self.layers: z, s = layer(z, l, g) s_list.append(s) s_list = torch.stack(s_list, dim=0).sum(dim=0) s_list = s_list / math.sqrt(len(self.layers)) o = self.end(s_list) return o def summarize(self, length=100): from ptflops import get_model_complexity_info x = torch.randn(1, self.input_dim, length) macs, params = get_model_complexity_info( self, (self.input_dim, length), as_strings=True, print_per_layer_stat=True, verbose=True, ) print(f"Input shape: {x.shape}") print(f"Computational complexity: {macs}") print(f"Number of parameters: {params}") if __name__ == "__main__": model = WN(input_dim=64, output_dim=64) model.summarize()