|
import time |
|
import os |
|
import random |
|
import numpy as np |
|
import torch |
|
import torch.utils.data |
|
|
|
import modules.commons as commons |
|
import utils |
|
from modules.mel_processing import spectrogram_torch, spec_to_mel_torch, spectrogram_torch |
|
from utils import load_wav_to_torch, load_filepaths_and_text |
|
|
|
|
|
|
|
|
|
"""Multi speaker version""" |
|
|
|
|
|
class TextAudioSpeakerLoader(torch.utils.data.Dataset): |
|
""" |
|
1) loads audio, speaker_id, text pairs |
|
2) normalizes text and converts them to sequences of integers |
|
3) computes spectrograms from audio files. |
|
""" |
|
|
|
def __init__(self, audiopaths, hparams, all_in_mem: bool = False, vol_aug: bool = True): |
|
self.audiopaths = load_filepaths_and_text(audiopaths) |
|
self.hparams = hparams |
|
self.max_wav_value = hparams.data.max_wav_value |
|
self.sampling_rate = hparams.data.sampling_rate |
|
self.filter_length = hparams.data.filter_length |
|
self.hop_length = hparams.data.hop_length |
|
self.win_length = hparams.data.win_length |
|
self.unit_interpolate_mode = hparams.data.unit_interpolate_mode |
|
self.sampling_rate = hparams.data.sampling_rate |
|
self.use_sr = hparams.train.use_sr |
|
self.spec_len = hparams.train.max_speclen |
|
self.spk_map = hparams.spk |
|
self.vol_emb = hparams.model.vol_embedding |
|
self.vol_aug = hparams.train.vol_aug and vol_aug |
|
random.seed(1234) |
|
random.shuffle(self.audiopaths) |
|
|
|
self.all_in_mem = all_in_mem |
|
if self.all_in_mem: |
|
self.cache = [self.get_audio(p[0]) for p in self.audiopaths] |
|
|
|
def get_audio(self, filename): |
|
filename = filename.replace("\\", "/") |
|
audio, sampling_rate = load_wav_to_torch(filename) |
|
if sampling_rate != self.sampling_rate: |
|
raise ValueError("{} SR doesn't match target {} SR".format( |
|
sampling_rate, self.sampling_rate)) |
|
audio_norm = audio / self.max_wav_value |
|
audio_norm = audio_norm.unsqueeze(0) |
|
spec_filename = filename.replace(".wav", ".spec.pt") |
|
|
|
|
|
if os.path.exists(spec_filename): |
|
spec = torch.load(spec_filename) |
|
else: |
|
spec = spectrogram_torch(audio_norm, self.filter_length, |
|
self.sampling_rate, self.hop_length, self.win_length, |
|
center=False) |
|
spec = torch.squeeze(spec, 0) |
|
torch.save(spec, spec_filename) |
|
|
|
spk = filename.split("/")[-2] |
|
spk = torch.LongTensor([self.spk_map[spk]]) |
|
|
|
f0, uv = np.load(filename + ".f0.npy",allow_pickle=True) |
|
|
|
f0 = torch.FloatTensor(np.array(f0,dtype=float)) |
|
uv = torch.FloatTensor(np.array(uv,dtype=float)) |
|
|
|
c = torch.load(filename+ ".soft.pt") |
|
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0], mode=self.unit_interpolate_mode) |
|
if self.vol_emb: |
|
volume_path = filename + ".vol.npy" |
|
volume = np.load(volume_path) |
|
volume = torch.from_numpy(volume).float() |
|
else: |
|
volume = None |
|
|
|
lmin = min(c.size(-1), spec.size(-1)) |
|
assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename) |
|
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length |
|
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin] |
|
audio_norm = audio_norm[:, :lmin * self.hop_length] |
|
if volume!= None: |
|
volume = volume[:lmin] |
|
return c, f0, spec, audio_norm, spk, uv, volume |
|
|
|
def random_slice(self, c, f0, spec, audio_norm, spk, uv, volume): |
|
|
|
|
|
|
|
|
|
if random.choice([True, False]) and self.vol_aug and volume!=None: |
|
max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5 |
|
max_shift = min(1, np.log10(1/max_amp)) |
|
log10_vol_shift = random.uniform(-1, max_shift) |
|
audio_norm = audio_norm * (10 ** log10_vol_shift) |
|
volume = volume * (10 ** log10_vol_shift) |
|
spec = spectrogram_torch(audio_norm, |
|
self.hparams.data.filter_length, |
|
self.hparams.data.sampling_rate, |
|
self.hparams.data.hop_length, |
|
self.hparams.data.win_length, |
|
center=False)[0] |
|
|
|
if spec.shape[1] > 800: |
|
start = random.randint(0, spec.shape[1]-800) |
|
end = start + 790 |
|
spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end] |
|
audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length] |
|
if volume !=None: |
|
volume = volume[start:end] |
|
return c, f0, spec, audio_norm, spk, uv,volume |
|
|
|
def __getitem__(self, index): |
|
if self.all_in_mem: |
|
return self.random_slice(*self.cache[index]) |
|
else: |
|
return self.random_slice(*self.get_audio(self.audiopaths[index][0])) |
|
|
|
def __len__(self): |
|
return len(self.audiopaths) |
|
|
|
|
|
class TextAudioCollate: |
|
|
|
def __call__(self, batch): |
|
batch = [b for b in batch if b is not None] |
|
|
|
input_lengths, ids_sorted_decreasing = torch.sort( |
|
torch.LongTensor([x[0].shape[1] for x in batch]), |
|
dim=0, descending=True) |
|
|
|
max_c_len = max([x[0].size(1) for x in batch]) |
|
max_wav_len = max([x[3].size(1) for x in batch]) |
|
|
|
lengths = torch.LongTensor(len(batch)) |
|
|
|
c_padded = torch.FloatTensor(len(batch), batch[0][0].shape[0], max_c_len) |
|
f0_padded = torch.FloatTensor(len(batch), max_c_len) |
|
spec_padded = torch.FloatTensor(len(batch), batch[0][2].shape[0], max_c_len) |
|
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) |
|
spkids = torch.LongTensor(len(batch), 1) |
|
uv_padded = torch.FloatTensor(len(batch), max_c_len) |
|
volume_padded = torch.FloatTensor(len(batch), max_c_len) |
|
|
|
c_padded.zero_() |
|
spec_padded.zero_() |
|
f0_padded.zero_() |
|
wav_padded.zero_() |
|
uv_padded.zero_() |
|
volume_padded.zero_() |
|
|
|
for i in range(len(ids_sorted_decreasing)): |
|
row = batch[ids_sorted_decreasing[i]] |
|
|
|
c = row[0] |
|
c_padded[i, :, :c.size(1)] = c |
|
lengths[i] = c.size(1) |
|
|
|
f0 = row[1] |
|
f0_padded[i, :f0.size(0)] = f0 |
|
|
|
spec = row[2] |
|
spec_padded[i, :, :spec.size(1)] = spec |
|
|
|
wav = row[3] |
|
wav_padded[i, :, :wav.size(1)] = wav |
|
|
|
spkids[i, 0] = row[4] |
|
|
|
uv = row[5] |
|
uv_padded[i, :uv.size(0)] = uv |
|
volume = row[6] |
|
if volume != None: |
|
volume_padded[i, :volume.size(0)] = volume |
|
else : |
|
volume_padded = None |
|
return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded, volume_padded |
|
|