# Source: audioldm_train/modules/phoneme_encoder/encoder.py
# (Hugging Face file-viewer header: last updated by jadechoghari, commit 43c92df, verified, 1.63 kB.)
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
import qa_mdt.audioldm_train.modules.phoneme_encoder.commons as commons
import qa_mdt.audioldm_train.modules.phoneme_encoder.attentions as attentions
class TextEncoder(nn.Module):
    """Embed token ids, run them through an encoder, and project to statistics.

    The module chains three submodules: an ``nn.Embedding`` lookup, an
    ``attentions.Encoder`` stack, and a 1x1 ``nn.Conv1d`` whose output has
    ``2 * out_channels`` channels that ``forward`` splits into two halves.

    Args:
        n_vocab: Number of rows in the embedding table.
        out_channels: Channel count of each of the two projected outputs.
        hidden_channels: Embedding width, also the encoder's channel count.
        filter_channels: Passed through to ``attentions.Encoder``.
        n_heads: Passed through to ``attentions.Encoder``.
        n_layers: Passed through to ``attentions.Encoder``.
        kernel_size: Passed through to ``attentions.Encoder``.
        p_dropout: Passed through to ``attentions.Encoder``.
    """

    def __init__(
        self,
        n_vocab,
        out_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
    ):
        super().__init__()

        # Keep the hyper-parameters on the instance; forward() reads
        # hidden_channels and out_channels.
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        # Embedding weights are drawn with std hidden_channels**-0.5; forward()
        # multiplies the looked-up vectors by sqrt(hidden_channels).
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, mean=0.0, std=hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )

        # 1x1 convolution producing both output halves in a single pass.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, kernel_size=1)

    def forward(self, x, x_lengths):
        """Encode a batch of token ids.

        Args:
            x: Integer indices into the embedding table
                (assumed ``[b, t]`` -- TODO confirm with callers).
            x_lengths: Forwarded to ``commons.sequence_mask`` together with the
                time dimension of the embedded input (presumably the valid
                length of each batch item -- verify against that helper).

        Returns:
            A 4-tuple ``(x, m, logs, x_mask)``: the encoder output, the two
            ``out_channels``-wide halves of the masked projection (split along
            dim 1), and the mask cast to the embedding dtype.
        """
        scale = math.sqrt(self.hidden_channels)
        # Embedding lookup gives [b, t, h]; swap to channels-first [b, h, t].
        embedded = (self.emb(x) * scale).transpose(1, -1)

        valid = commons.sequence_mask(x_lengths, embedded.size(2))
        x_mask = valid.unsqueeze(1).to(embedded.dtype)

        encoded = self.encoder(embedded * x_mask, x_mask)
        stats = self.proj(encoded) * x_mask
        m, logs = stats.split(self.out_channels, dim=1)
        return encoded, m, logs, x_mask