from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class RelativePositionTransformerEncoder(nn.Module):
    """Speedy speech encoder built on Transformer with Relative Position encoding.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = ResidualConv1dBNBlock(
            in_channels,
            hidden_channels,
            hidden_channels,
            kernel_size=5,
            num_res_blocks=3,
            num_conv_blocks=1,
            dilations=[1, 1, 1],
        )
        self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        if x_mask is None:
            x_mask = 1
        o = self.prenet(x) * x_mask
        o = self.rel_pos_transformer(o, x_mask)
        return o
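

# A minimal usage sketch for RelativePositionTransformerEncoder (the channel sizes and the
# sequence length below are illustrative assumptions; the param values are the defaults
# documented in the Encoder docstring further down, and `torch` would need to be imported):
#
#   params = {"hidden_channels_ffn": 128, "num_heads": 2, "kernel_size": 3,
#             "dropout_p": 0.1, "num_layers": 6, "rel_attn_window_size": 4}
#   enc = RelativePositionTransformerEncoder(128, 128, 128, params)
#   o = enc(torch.rand(2, 128, 37), x_mask=torch.ones(2, 1, 37))  # -> [B, 128, T]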


class ResidualConv1dBNEncoder(nn.Module):
    """Residual Convolutional Encoder as in the original Speedy Speech paper

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)

        self.postnet = nn.Sequential(
            *[
                nn.Conv1d(hidden_channels, hidden_channels, 1),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_channels),
                nn.Conv1d(hidden_channels, out_channels, 1),
            ]
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        if x_mask is None:
            x_mask = 1
        o = self.prenet(x) * x_mask
        o = self.res_conv_block(o, x_mask)
        o = self.postnet(o + x) * x_mask
        return o * x_mask
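

# A minimal usage sketch for ResidualConv1dBNEncoder (shapes are assumptions; the param
# values are the 'residual_conv_bn' defaults documented in the Encoder docstring below).
# Note that the residual `o + x` in forward() implies in_channels == hidden_channels:
#
#   params = {"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1],
#             "num_conv_blocks": 2, "num_res_blocks": 13}
#   enc = ResidualConv1dBNEncoder(128, 128, 128, params)
#   o = enc(torch.rand(2, 128, 37), x_mask=torch.ones(2, 1, 37))  # -> [B, 128, T]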


class Encoder(nn.Module):
    # pylint: disable=dangerous-default-value
    """Factory class for Speedy Speech encoder enables different encoder types internally.

    Args:
        num_chars (int): number of characters.
        out_channels (int): number of output channels.
        in_hidden_channels (int): number of input and hidden channels. The model keeps the input channel size for the intermediate layers.
        encoder_type (str): encoder type. One of 'relative_position_transformer', 'residual_conv_bn' or 'fftransformer'. Defaults to 'residual_conv_bn'.
        encoder_params (dict): model parameters for specified encoder type.
        c_in_channels (int): number of channels for conditional input.

    Note:
        Default encoder_params values to set in config.json for each encoder type:

        ```python
        # for 'relative_position_transformer'
        encoder_params={
            'hidden_channels_ffn': 128,
            'num_heads': 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None
        },

        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13
        }

        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024 ,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1
        }
        ```
    """

    def __init__(
        self,
        in_hidden_channels,
        out_channels,
        encoder_type="residual_conv_bn",
        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
        c_in_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init encoder
        if encoder_type.lower() == "relative_position_transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels, encoder_params
            )
        elif encoder_type.lower() == "residual_conv_bn":
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
        elif encoder_type.lower() == "fftransformer":
            assert (
                in_hidden_channels == out_channels
            ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
            # pylint: disable=unexpected-keyword-arg
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
        else:
            raise NotImplementedError(" [!] unknown encoder type.")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask
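

# A standalone smoke-test sketch (tensor shapes and values are assumptions, not part of the
# library API); it exercises two of the supported encoder types with dummy inputs.
if __name__ == "__main__":
    import torch

    batch, channels, time = 2, 128, 37
    dummy_x = torch.rand(batch, channels, time)  # [B, C, T]
    dummy_mask = torch.ones(batch, 1, time)  # [B, 1, T]

    # default 'residual_conv_bn' encoder
    encoder = Encoder(in_hidden_channels=channels, out_channels=channels)
    print(encoder(dummy_x, dummy_mask).shape)  # expected: [2, 128, 37]

    # 'fftransformer' encoder (requires in_hidden_channels == out_channels)
    encoder = Encoder(
        in_hidden_channels=channels,
        out_channels=channels,
        encoder_type="fftransformer",
        encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
    )
    print(encoder(dummy_x, dummy_mask).shape)  # expected: [2, 128, 37]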