# Author: Bingxin Ke
# Last modified: 2024-02-22

import torch


def get_loss(loss_name, **kwargs):
    if loss_name == "silog_mse":
        criterion = SILogMSELoss(**kwargs)
    elif loss_name == "silog_rmse":
        criterion = SILogRMSELoss(**kwargs)
    elif loss_name == "mse_loss":
        criterion = torch.nn.MSELoss(**kwargs)
    elif loss_name == "l1_loss":
        criterion = torch.nn.L1Loss(**kwargs)
    elif loss_name == "l1_loss_with_mask":
        criterion = L1LossWithMask(**kwargs)
    elif loss_name == "mean_abs_rel":
        criterion = MeanAbsRelLoss()
    else:
        raise NotImplementedError(f"Unknown loss: {loss_name}")
    return criterion


class L1LossWithMask:
    def __init__(self, batch_reduction=False):
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        diff = depth_pred - depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        # Sum over the spatial dims only, so the loss stays per-sample
        # until the (optional) batch reduction below.
        loss = torch.sum(torch.abs(diff), (-1, -2)) / n
        if self.batch_reduction:
            loss = loss.mean()
        return loss


class MeanAbsRelLoss:
    def __call__(self, pred, gt):
        diff = pred - gt
        rel_abs = torch.abs(diff / gt)
        # Reduce over all dims so the result is a scalar loss.
        loss = torch.mean(rel_abs)
        return loss


class SILogMSELoss:
    def __init__(self, lamb, log_pred=True, batch_reduction=True):
        """Scale-invariant log MSE loss.

        Args:
            lamb (float): lambda; lambda=1 -> fully scale-invariant,
                lambda=0 -> plain MSE in log space.
            log_pred (bool, optional): True if the model predicts logarithmic
                depth; in that case no log is applied to depth_pred.
            batch_reduction (bool, optional): True to average the per-sample
                losses over the batch.
        """
        self.lamb = lamb
        self.pred_in_log = log_pred
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        # Clip before the log to avoid -inf on (near-)zero predictions.
        log_depth_pred = (
            depth_pred
            if self.pred_in_log
            else torch.log(torch.clip(depth_pred, 1e-8))
        )
        log_depth_gt = torch.log(depth_gt)

        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]

        diff2 = torch.pow(diff, 2)
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        loss = first_term - second_term
        if self.batch_reduction:
            loss = loss.mean()
        return loss


class SILogRMSELoss:
    def __init__(self, lamb, alpha, log_pred=True):
        """Scale-invariant log RMSE loss.

        Args:
            lamb (float): lambda; lambda=1 -> fully scale-invariant,
                lambda=0 -> plain RMSE in log space.
            alpha (float): constant scaling factor applied to the final loss.
            log_pred (bool, optional): True if the model predicts logarithmic
                depth; in that case no log is applied to depth_pred.
        """
        self.lamb = lamb
        self.alpha = alpha
        self.pred_in_log = log_pred

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred)
        log_depth_gt = torch.log(depth_gt)
        # Adapted from https://github.com/aliyun/NeWCRFs:
        # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask]
        # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha
        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]

        diff2 = torch.pow(diff, 2)
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        loss = torch.sqrt(first_term - second_term).mean() * self.alpha
        return loss
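

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal smoke test for the losses above, assuming depth tensors of shape
# (B, H, W) and a boolean validity mask of the same shape; all values are
# arbitrary placeholders.
if __name__ == "__main__":
    torch.manual_seed(0)
    depth_gt = torch.rand(2, 4, 4) + 0.5  # strictly positive ground truth
    depth_pred = (depth_gt + 0.1 * torch.randn(2, 4, 4)).clamp(min=1e-3)
    valid_mask = depth_gt > 0.6  # drop some pixels, as a sparse GT would

    criterion = get_loss("silog_mse", lamb=0.5, log_pred=False)
    print("SILog MSE:", criterion(depth_pred, depth_gt, valid_mask))

    criterion = get_loss("l1_loss_with_mask", batch_reduction=True)
    print("Masked L1:", criterion(depth_pred, depth_gt, valid_mask))

    # With lamb=1 the SILog MSE reduces to the variance of the log-residual,
    # so rescaling the prediction by a constant should leave it unchanged.
    criterion = get_loss("silog_mse", lamb=1.0, log_pred=False)
    base = criterion(depth_pred, depth_gt)
    scaled = criterion(2.0 * depth_pred, depth_gt)
    print("Scale-invariant (lamb=1):", torch.allclose(base, scaled, atol=1e-4))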