import torch
from torch import nn

from .. import AudioSignal


class L1Loss(nn.L1Loss):
    """L1 loss between AudioSignals. Defaults to comparing ``audio_data``,
    but any attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use the scale-invariant (True) or plain
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean', 'sum',
        or 'none'), by default 'mean'
    zero_mean : bool, optional
        Zero-mean the references and estimates before computing
        the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps the network not focus
        on making already-good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: float = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Reference AudioSignal (or a raw tensor shaped (nb, nc, nt)).
        y : AudioSignal
            Estimate AudioSignal (or a raw tensor shaped (nb, nc, nt)).

        Returns
        -------
        torch.Tensor
            Negated SI-SDR (or SNR), reduced across the batch.
        """
        eps = 1e-8
        # Expected input shape: (nb, nc, nt) — batch, channels, time.
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        # Flatten channels into time and move samples onto axis 1,
        # giving tensors of shape (nb, nt, 1).
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        # Project the estimates onto the references to find the optimal
        # rescaling of the reference; this is the scale-invariant part.
        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        # Split the estimate into a target component and a residual, then
        # compute their energy ratio in dB, negated so that lower is better.
        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
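

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). Both losses fall back
# to plain tensors when they are not given AudioSignals, so the demo below
# uses random (nb, nc, nt) tensors and avoids constructing an AudioSignal.
# Because of the relative import at the top, run this as a module, e.g.
# ``python -m <package>.<this_module>`` — the package path is assumed here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    reference = torch.randn(2, 1, 16000)  # (batch, channels, time)
    estimate = reference + 0.1 * torch.randn_like(reference)

    # L1Loss compares tensors directly when given non-AudioSignal inputs.
    l1 = L1Loss()
    print("L1 loss:", l1(estimate, reference).item())

    # SISDRLoss takes the reference as its first argument (see forward above)
    # and returns the negated SI-SDR, so a close estimate yields a large
    # negative value.
    sisdr = SISDRLoss()
    print("SI-SDR loss:", sisdr(reference, estimate).item())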