import torch
import torch.nn.functional as F

from modules.tts.fs2_orig import FastSpeech2Orig
from tasks.tts.dataset_utils import FastSpeechDataset
from tasks.tts.fs import FastSpeechTask
from utils.commons.dataset_utils import collate_1d, collate_2d
from utils.commons.hparams import hparams


class FastSpeech2OrigDataset(FastSpeechDataset):
    """FastSpeech dataset that additionally yields energy and CWT pitch targets.

    On top of whatever the parent dataset returns, each sample carries an
    ``energy`` tensor derived from its mel, and, when the pitch embedding is
    enabled with ``pitch_type == 'cwt'``, the ``cwt_spec`` / ``f0_mean`` /
    ``f0_std`` entries read from the stored item.
    """

    def __init__(self, prefix, shuffle=False, items=None, data_dir=None):
        super().__init__(prefix, shuffle, items, data_dir)
        # May be None when 'pitch_type' is not configured.
        self.pitch_type = hparams.get('pitch_type')

    def __getitem__(self, index):
        """Return the parent sample extended with energy (and CWT) targets.

        Args:
            index: Index forwarded to the parent ``__getitem__`` and to
                ``_get_item``.

        Returns:
            The sample dict produced by the parent class, updated in place.
        """
        sample = super().__getitem__(index)
        mel = sample['mel']
        num_frames = mel.shape[0]
        # Per-frame energy: L2 norm of exp(mel) over the last axis.
        sample['energy'] = (mel.exp() ** 2).sum(-1).sqrt()
        if self.hparams['use_pitch_embed'] and self.pitch_type == 'cwt':
            # Only fetch the stored item when its CWT fields are needed.
            item = self._get_item(index)
            # Truncate to the mel length so the targets line up frame by frame.
            cwt_spec = torch.Tensor(item['cwt_spec'])[:num_frames]
            # Fall back to the 'cwt_*' keys when the 'f0_*' keys are absent.
            f0_mean = item.get('f0_mean', item.get('cwt_mean'))
            f0_std = item.get('f0_std', item.get('cwt_std'))
            sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
        return sample

    def collater(self, samples):
        """Merge samples into a batch, adding energy and CWT pitch targets.

        Args:
            samples: Sequence of dicts as returned by ``__getitem__``.

        Returns:
            ``{}`` for an empty input; otherwise the parent batch extended with
            ``energy`` (``None`` when neither the pitch nor the energy embedding
            is enabled) and, for CWT pitch, ``cwt_spec`` / ``f0_mean`` /
            ``f0_std``.
        """
        if not samples:
            return {}
        batch = super().collater(samples)

        # Energy is consumed by the energy loss (gated on 'use_energy_embed'),
        # so it must be collated whenever that flag is set, not only when the
        # pitch embedding is on. The pitch flag is kept for backward
        # compatibility with configs that relied on it.
        needs_energy = hparams['use_pitch_embed'] or hparams.get('use_energy_embed', False)
        if needs_energy:
            # 0.0 is passed as the second collate_1d argument
            # (presumably the padding value; verify against collate_1d).
            energy = collate_1d([s['energy'] for s in samples], 0.0)
        else:
            energy = None
        batch.update({'energy': energy})

        # Same condition as in __getitem__: the CWT keys only exist on the
        # samples when the pitch embedding is enabled.
        if hparams['use_pitch_embed'] and self.pitch_type == 'cwt':
            cwt_spec = collate_2d([s['cwt_spec'] for s in samples])
            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
            f0_std = torch.Tensor([s['f0_std'] for s in samples])
            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
        return batch


class FastSpeech2OrigTask(FastSpeechTask):
    """Training/inference task wiring ``FastSpeech2Orig`` to its dataset and losses."""

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeech2OrigDataset

    def build_tts_model(self):
        """Instantiate ``FastSpeech2Orig`` sized to the token encoder vocabulary."""
        dict_size = len(self.token_encoder)
        self.model = FastSpeech2Orig(dict_size, hparams)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on a batch for training or inference.

        Args:
            sample: Batch dict. Must contain ``txt_tokens``; training also
                reads ``mels`` and ``mel2ph``.
            infer: When False, compute losses; when True, only synthesize.
            **kwargs: ``infer_use_gt_dur`` / ``infer_use_gt_f0`` /
                ``infer_use_gt_energy`` override the matching ``use_gt_*``
                hparams during inference.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise ``output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')

        if infer:
            # Ground-truth conditioning is opt-in at inference time.
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            use_gt_energy = kwargs.get('infer_use_gt_energy', hparams['use_gt_energy'])
            mel2ph = sample['mel2ph'] if use_gt_dur else None
            if use_gt_f0:
                f0, uv = sample['f0'], sample['uv']
            else:
                f0, uv = None, None
            energy = sample['energy'] if use_gt_energy else None
            return self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                              f0=f0, uv=uv, energy=energy, infer=True)

        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample.get('f0')
        uv = sample.get('uv')
        energy = sample.get('energy')
        output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                            f0=f0, uv=uv, energy=energy, infer=False)
        losses = {}
        self.add_mel_loss(output['mel_out'], target, losses)
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output, sample, losses)
        return losses, output

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses to ``losses``; CWT is handled here, the rest by the parent.

        For ``pitch_type == 'cwt'`` this writes ``losses['C']``,
        ``losses['f0_mean']``, ``losses['f0_std']`` and, when ``use_uv`` is
        set, ``losses['uv']``.
        """
        if hparams['pitch_type'] != 'cwt':
            super().add_pitch_loss(output, sample, losses)
            return

        cwt_spec = sample['cwt_spec']
        f0_mean = sample['f0_mean']
        uv = sample['uv']
        mel2ph = sample['mel2ph']
        f0_std = sample['f0_std']
        # The first 10 channels of the 'cwt' output are compared with cwt_spec.
        cwt_pred = output['cwt'][:, :, :10]
        f0_mean_pred = output['f0_mean']
        f0_std_pred = output['f0_std']
        # Positions where mel2ph is 0 are excluded from the uv loss.
        nonpadding = (mel2ph != 0).float()
        losses['C'] = F.l1_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
        if hparams['use_uv']:
            # With uv prediction the output has an 11th channel: the uv logit.
            assert output['cwt'].shape[-1] == 11
            uv_pred = output['cwt'][:, :, -1]
            uv_bce = F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none')
            losses['uv'] = (uv_bce * nonpadding).sum() / nonpadding.sum() * hparams['lambda_uv']
        losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
        losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']

    def add_energy_loss(self, output, sample, losses):
        """Add the masked MSE energy loss to ``losses`` under key ``'e'``.

        Positions whose target energy equals 0 are excluded from the average.
        """
        energy_pred, energy = output['energy_pred'], sample['energy']
        nonpadding = (energy != 0).float()
        squared_err = F.mse_loss(energy_pred, energy, reduction='none')
        loss = (squared_err * nonpadding).sum() / nonpadding.sum()
        losses['e'] = loss * hparams['lambda_energy']