from torch import nn


class ZeroTemporalPad(nn.Module):
    """Pad sequences to equal lentgh in the temporal dimension"""

    def __init__(self, kernel_size, dilation):
        super().__init__()
        total_pad = dilation * (kernel_size - 1)
        begin = total_pad // 2
        end = total_pad - begin
        self.pad_layer = nn.ZeroPad2d((0, 0, begin, end))

    def forward(self, x):
        return self.pad_layer(x)
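
# Illustrative usage (a sketch, not part of the original module): ZeroTemporalPad is
# assumed here to act on (B, T, D) tensors, where nn.ZeroPad2d pads the second-to-last
# (temporal) axis. With kernel_size=3 and dilation=1 it adds 2 frames in total, split
# between the beginning and the end of the sequence:
#
#   pad = ZeroTemporalPad(kernel_size=3, dilation=1)
#   x = torch.randn(8, 100, 80)   # (B, T, D)
#   pad(x).shape                  # torch.Size([8, 102, 80])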


class Conv1dBN(nn.Module):
    """1d convolutional with batch norm.
    conv1d -> relu -> BN blocks.

    Note:
        Batch normalization is applied after ReLU regarding the original implementation.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        padding = dilation * (kernel_size - 1)
        pad_s = padding // 2
        pad_e = padding - pad_s
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0))  # uneven left and right padding
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        o = self.conv1d(x)
        o = self.pad(o)
        o = nn.functional.relu(o)
        o = self.norm(o)
        return o
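
# Illustrative shape check (an assumption, not in the original file): the convolution is
# applied without padding, so its output is shorter by dilation * (kernel_size - 1) frames;
# the ZeroPad2d call then restores the original temporal length, acting on a (B, C, T)
# tensor so the (pad_s, pad_e, 0, 0) tuple pads the last axis.
#
#   layer = Conv1dBN(in_channels=80, out_channels=128, kernel_size=5, dilation=2)
#   x = torch.randn(4, 80, 50)    # (B, C, T)
#   layer(x).shape                # torch.Size([4, 128, 50])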


class Conv1dBNBlock(nn.Module):
    """1d convolutional block with batch norm. It is a set of conv1d -> relu -> BN blocks.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
        num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
        super().__init__()
        self.conv_bn_blocks = []
        for idx in range(num_conv_blocks):
            layer = Conv1dBN(
                in_channels if idx == 0 else hidden_channels,
                out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
                kernel_size,
                dilation,
            )
            self.conv_bn_blocks.append(layer)
        self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks)

    def forward(self, x):
        """
        Shapes:
            x: (B, D, T)
        """
        return self.conv_bn_blocks(x)
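
# Sketch with assumed shapes (not part of the original module): with num_conv_blocks=2 the
# block maps (B, in_channels, T) -> (B, hidden_channels, T) -> (B, out_channels, T),
# keeping T fixed because each Conv1dBN re-pads to the input length.
#
#   block = Conv1dBNBlock(80, 128, 256, kernel_size=3, dilation=1, num_conv_blocks=2)
#   block(torch.randn(2, 80, 60)).shape   # torch.Size([2, 128, 60])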


class ResidualConv1dBNBlock(nn.Module):
    """Residual Convolutional Blocks with BN
    Each block has 'num_conv_block' conv layers and 'num_res_blocks' such blocks are connected
    with residual connections.

    conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks'
    residuak_conv_block =  (x -> conv_block ->  + ->) x 'num_res_blocks'
                            ' - - - - - - - - - ^
    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilations (list): dilations for each convolution layer.
        num_res_blocks (int, optional): number of residual blocks. Defaults to 13.
        num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2.
    """

    def __init__(
        self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2
    ):
        super().__init__()
        assert len(dilations) == num_res_blocks
        self.res_blocks = nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            block = Conv1dBNBlock(
                in_channels if idx == 0 else hidden_channels,
                out_channels if (idx + 1) == len(dilations) else hidden_channels,
                hidden_channels,
                kernel_size,
                dilation,
                num_conv_blocks,
            )
            self.res_blocks.append(block)

    def forward(self, x, x_mask=None):
        """
        Shapes:
            x: (B, D, T)
            x_mask: (B, 1, T)
        """
        if x_mask is None:
            x_mask = 1.0
        o = x * x_mask
        for block in self.res_blocks:
            res = o
            o = block(o)
            o = o + res
            o = o * x_mask  # x_mask defaults to 1.0, so this is a no-op when no mask is given
        return o
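

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module).
    # It assumes matching in/hidden/out channel sizes, which the residual additions
    # above require for the first and last blocks, and a (B, 1, T) mask that
    # broadcasts over the channel dimension.
    import torch

    layer = ResidualConv1dBNBlock(
        in_channels=128,
        out_channels=128,
        hidden_channels=128,
        kernel_size=3,
        dilations=[1, 2, 4],
        num_res_blocks=3,
        num_conv_blocks=2,
    )
    x = torch.randn(2, 128, 37)    # (B, D, T)
    x_mask = torch.ones(2, 1, 37)  # keep all frames
    assert layer(x, x_mask).shape == (2, 128, 37)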