import typing
from typing import List

import numpy as np
import torch
from torch import nn

from .. import AudioSignal
from .. import STFTParams


class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    For every STFT resolution, the estimate and reference magnitude
    spectrograms are compared twice: once on a (clamped, powered) log10
    scale and once on a linear scale. The per-scale terms are summed.

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        Each STFT uses a hop length of ``window_length // 4``.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower bound applied to the magnitudes before taking the log,
        by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is stored as ``self.weight``
        but is not applied inside ``forward``; presumably the caller
        applies it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    window_type : str, optional
        Window type forwarded to every ``STFTParams``, by default None

    References
    ----------

    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; hop is a quarter of the window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference signal.

        Note that ``x.stft(...)`` and ``y.stft(...)`` are called on the
        signals passed in, once per scale.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss (the float 0.0 if there are no scales).
        """
        loss = 0.0
        for s in self.stft_params:
            x.stft(s.window_length, s.hop_length, s.window_type)
            y.stft(s.window_length, s.hop_length, s.window_type)
            # Read each magnitude once per scale instead of once per use.
            x_mag = x.magnitude
            y_mag = y.magnitude
            loss += self.log_weight * self.loss_fn(
                x_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
        return loss


class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Each scale combines a log10 term (on clamped, powered mel magnitudes)
    and a linear term, exactly like ``MultiScaleSTFTLoss``.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        Each STFT uses a hop length of ``window_length // 4``.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower bound applied to the mel magnitudes before taking the log,
        by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is stored as ``self.weight``
        but is not applied inside ``forward``; presumably the caller
        applies it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    mel_fmin : List[float], optional
        Lowest mel filter frequency per scale, by default [0.0, 0.0]
    mel_fmax : List[float], optional
        Highest mel filter frequency per scale, by default [None, None]
    window_type : str, optional
        Window type forwarded to every ``STFTParams``, by default None

    Notes
    -----
    ``n_mels``, ``window_lengths``, ``mel_fmin`` and ``mel_fmax`` are
    consumed in lockstep with ``zip``. If their lengths differ, only the
    first ``min(len(...))`` scales are used. In particular, passing more
    than two window lengths without also extending ``mel_fmin`` and
    ``mel_fmax`` silently drops the extra scales.
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        # Copy the per-scale lists so that mutating an instance attribute
        # can never alter the shared mutable default arguments above.
        self.n_mels = list(n_mels)
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = list(mel_fmin)
        self.mel_fmax = list(mel_fmax)
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss (the float 0.0 if there are no scales).
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss


class PhaseLoss(nn.Module):
    """Difference between phase spectrograms.

    Parameters
    ----------
    window_length : int, optional
        Length of STFT window, by default 2048
    hop_length : int, optional
        Hop length of STFT window, by default 512
    weight : float, optional
        Weight of loss, by default 1.0. It is stored as ``self.weight`` but
        is not applied inside ``forward``; presumably the caller applies it.
    """

    def __init__(
        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
    ):
        super().__init__()

        self.weight = weight
        self.stft_params = STFTParams(window_length, hop_length)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes phase loss between an estimate and a reference signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Phase loss: mean of ``(weights * wrapped_phase_diff) ** 2``.
        """
        s = self.stft_params
        x.stft(s.window_length, s.hop_length, s.window_type)
        y.stft(s.window_length, s.hop_length, s.window_type)

        # Take circular difference: wrap each element back toward [-pi, pi]
        # with a single +/- 2*pi shift. This is sufficient if both phases lie
        # in [-pi, pi] (assumed, TODO confirm), since then diff is in
        # [-2*pi, 2*pi]. Values above pi must have 2*pi *subtracted*; the
        # previous code subtracted -2*pi, i.e. added it.
        diff = x.phase - y.phase
        diff[diff < -np.pi] += 2 * np.pi
        diff[diff > np.pi] -= 2 * np.pi

        # Min-max normalize the estimate's magnitude (x) to weights in [0, 1].
        # NOTE(review): an older comment called this the "true" magnitude,
        # but x is documented as the estimate; confirm whether y was intended.
        x_mag = x.magnitude
        x_min, x_max = x_mag.min(), x_mag.max()
        # Guard a zero range (constant/silent estimate), which would
        # otherwise produce 0/0 = NaN weights. Clamping to the dtype's
        # smallest positive normal leaves every non-degenerate range as-is.
        x_range = (x_max - x_min).clamp_min(torch.finfo(x_mag.dtype).tiny)
        weights = (x_mag - x_min) / x_range

        # Take weighted mean of all phase errors
        loss = ((weights * diff) ** 2).mean()
        return loss