# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch as th
import numpy as np
import logging

from .vgg import VGGLossMasked

# f-string so the logger is named after this module under the "dva" hierarchy
# (without the f-prefix the name would be the literal text "dva.{__name__}").
logger = logging.getLogger(f"dva.{__name__}")


class DCTLoss(nn.Module):
    """VAE loss comparing reconstruction and target in the frequency domain.

    Despite the name, the spectrum is computed with ``th.fft.fft`` (a complex
    FFT over each flattened sample), not a DCT. The real and imaginary parts
    of the two spectra are compared with an L1 penalty.

    ``weights`` must provide ``recon`` and ``kl`` via attribute access
    (presumably an attribute-style config object -- verify against caller).
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        """Compute the frequency-domain reconstruction loss plus KL.

        Args:
            inputs: mapping holding the target tensor under ``'gt'``.
            preds: mapping holding ``'recon'`` (expected to match the target's
                shape) and ``'posterior'`` (an object exposing ``kl()``).
            iteration: unused; accepted for a uniform loss interface.

        Returns:
            ``(loss_total, loss_dict)``. ``loss_recon_l1`` (plain L1 in the
            original domain) is reported for monitoring only; it does not
            contribute to ``loss_total``.
        """
        loss_dict = {"loss_total": 0.0}

        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']

        # Flatten all but the leading dim, FFT along it, then view the complex
        # result as trailing (real, imag) float pairs so abs() is per-part.
        fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
        fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
        loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
        loss_recon_l1 = th.mean(th.abs(target - recon))
        loss_kl = posterior.kl().mean()
        loss_dict.update(
            loss_recon_l1=loss_recon_l1,
            loss_recon_dct_l1=loss_recon_dct_l1,
            loss_kl=loss_kl,
        )

        loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


def _vae_sep_forward(weights, inputs, preds, elementwise_error):
    """Shared body of the per-channel-group VAE losses.

    Applies ``elementwise_error`` to ``target - recon``, splits the result
    along dim 1 into the channel groups ``[0:1]`` (sdf), ``[1:4]`` (rgb) and
    ``[4:6]`` (mat), and averages each group. The groups are combined with
    ``weights.sdf`` / ``weights.rgb`` / ``weights.mat``. The KL term is always
    reported but only added to the total when ``"kl" in weights``.

    Returns:
        ``(loss_total, loss_dict)``.
    """
    loss_dict = {"loss_total": 0.0}

    target = inputs['gt']
    recon = preds['recon']
    posterior = preds['posterior']

    recon_diff = elementwise_error(target - recon)
    loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
    loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
    loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
    loss_kl = posterior.kl().mean()
    loss_dict.update(
        loss_sdf_l1=loss_recon_sdf_l1,
        loss_rgb_l1=loss_recon_rgb_l1,
        loss_mat_l1=loss_recon_mat_l1,
        loss_kl=loss_kl,
    )

    loss_total = (
        weights.sdf * loss_recon_sdf_l1
        + weights.rgb * loss_recon_rgb_l1
        + weights.mat * loss_recon_mat_l1
    )
    if "kl" in weights:
        loss_total += weights.kl * loss_kl

    loss_dict["loss_total"] = loss_total
    return loss_total, loss_dict


class VAESepL2Loss(nn.Module):
    """Per-channel-group squared-error VAE loss (sdf / rgb / mat groups).

    Note: the reported keys keep their historical ``*_l1`` suffix for
    compatibility with existing logging, although the values are mean
    squared errors.
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        """Return ``(loss_total, loss_dict)``; ``iteration`` is unused."""
        return _vae_sep_forward(self.weights, inputs, preds, lambda diff: diff ** 2)


class VAESepLoss(nn.Module):
    """Per-channel-group L1 VAE loss (sdf / rgb / mat groups)."""

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        """Return ``(loss_total, loss_dict)``; ``iteration`` is unused."""
        return _vae_sep_forward(self.weights, inputs, preds, th.abs)


class VAELoss(nn.Module):
    """Plain VAE loss: L1 reconstruction over all channels plus weighted KL."""

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        """Return ``(loss_total, loss_dict)``; ``iteration`` is unused.

        Reads ``inputs['gt']``, ``preds['recon']`` and ``preds['posterior']``
        (an object exposing ``kl()``); ``weights`` must provide ``recon`` and
        ``kl``.
        """
        loss_dict = {"loss_total": 0.0}

        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']

        loss_recon_l1 = th.mean(th.abs(target - recon))
        loss_kl = posterior.kl().mean()
        loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)

        loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class PrimSDFLoss(nn.Module):
    """Staged loss for fitting primitives: shape first, then texture.

    Stages, selected by the ``iteration`` passed to ``forward``:

    * ``iteration < shape_opt_steps``: L1 on ``sdf`` plus, if
      ``"vol_sum"`` is in ``weights``, a penalty on the summed product of
      ``1 / prim_scale`` over the last dim.
    * ``shape_opt_steps <= iteration < tex_opt_steps``: L1 on ``tex`` plus,
      if ``"mat_l1"`` is in ``weights``, L1 on ``mat``.
    * If ``"grad_l2"`` is in ``weights``: MSE between ``preds["grad"]`` and
      ``inputs["grad"]`` is added in every stage.

    For ``iteration >= tex_opt_steps`` no stage term is active and the total
    is ``0.0`` plus the optional ``grad_l2`` term.
    """

    def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
        super().__init__()
        self.weights = weights
        self.shape_opt_steps = shape_opt_steps
        self.tex_opt_steps = tex_opt_steps

    def forward(self, inputs, preds, iteration=None):
        """Compute the loss for the stage selected by ``iteration``.

        Raises:
            TypeError: if ``iteration`` is ``None``; the stage cannot be
                chosen without it.
        """
        if iteration is None:
            raise TypeError(
                "PrimSDFLoss.forward requires `iteration` to select the optimization stage"
            )

        loss_dict = {"loss_total": 0.0}
        # Initialized so iterations past tex_opt_steps (no stage branch runs)
        # don't raise UnboundLocalError below.
        loss_total = 0.0

        if iteration < self.shape_opt_steps:
            target_sdf = inputs['sdf']
            sdf = preds['sdf']
            loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
            loss_dict.update(loss_sdf_l1=loss_sdf_l1)
            loss_total = self.weights.sdf_l1 * loss_sdf_l1

            prim_scale = preds["prim_scale"]
            # we use 1/scale instead of the original 100/scale as our scale is
            # normalized to [-1, 1] cube
            if "vol_sum" in self.weights:
                loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
                loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
                loss_total += self.weights.vol_sum * loss_prim_vol_sum

        elif self.shape_opt_steps <= iteration < self.tex_opt_steps:
            target_tex = inputs['tex']
            tex = preds['tex']
            loss_tex_l1 = th.mean(th.abs(tex - target_tex))
            loss_dict.update(loss_tex_l1=loss_tex_l1)
            loss_total = self.weights.rgb_l1 * loss_tex_l1

            if "mat_l1" in self.weights:
                target_mat = inputs['mat']
                mat = preds['mat']
                loss_mat_l1 = th.mean(th.abs(mat - target_mat))
                loss_dict.update(loss_mat_l1=loss_mat_l1)
                loss_total += self.weights.mat_l1 * loss_mat_l1

        if "grad_l2" in self.weights:
            loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
            loss_total += self.weights.grad_l2 * loss_grad_l2
            loss_dict.update(loss_grad_l2=loss_grad_l2)

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class TotalMVPLoss(nn.Module):
    """Image-space loss: masked RGB MSE, mask MAE, opacity prior, primitive volume.

    Optional extra terms are enabled by the presence of their key in
    ``weights``: ``embs_l2``, ``vgg`` (masked VGG perceptual loss) and
    ``prim_scale_var`` (variance of log primitive scales).
    ``assets`` is accepted for interface compatibility and is unused.
    """

    def __init__(self, weights, assets=None):
        super().__init__()
        self.weights = weights
        if "vgg" in self.weights:
            self.vgg_loss = VGGLossMasked()

    def forward(self, inputs, preds, iteration=None):
        """Return ``(loss_total, loss_dict)``; ``iteration`` is unused."""
        loss_dict = {"loss_total": 0.0}

        # rgb: target is converted channel-first -> channel-last to match preds["rgb"]
        target_rgb = inputs["image"].permute(0, 2, 3, 1)
        # apply channel 0 of the image mask to the target (broadcast over color)
        target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]

        rgb = preds["rgb"]
        loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
        loss_dict.update(loss_rgb_mse=loss_rgb_mse)

        alpha = preds["alpha"]

        # mask loss
        target_mask = inputs["image_mask"][:, 0].to(th.float32)
        loss_mask_mae = th.mean((target_mask - alpha).abs())
        loss_dict.update(loss_mask_mae=loss_mask_mae)

        B = alpha.shape[0]

        # Beta-like prior on opacity pushing alpha toward 0 or 1.
        # -2.20727 ~= log(0.1) + log(1.1), the minimum of
        # log(0.1 + a) + log(1.1 - a) on [0, 1] (reached at both endpoints),
        # so subtracting it keeps the term >= 0 when alpha lies in [0, 1].
        loss_alpha_prior = th.mean(
            th.log(0.1 + alpha.reshape(B, -1))
            + th.log(0.1 + 1.0 - alpha.reshape(B, -1))
            - (-2.20727)
        )
        loss_dict.update(loss_alpha_prior=loss_alpha_prior)

        prim_scale = preds["prim_scale"]
        loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
        loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)

        loss_total = (
            self.weights.rgb_mse * loss_rgb_mse
            + self.weights.mask_mae * loss_mask_mae
            + self.weights.alpha_prior * loss_alpha_prior
            + self.weights.prim_vol_sum * loss_prim_vol_sum
        )

        if "embs_l2" in self.weights:
            loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
            loss_total += self.weights.embs_l2 * loss_embs_l2
            loss_dict.update(loss_embs_l2=loss_embs_l2)

        if "vgg" in self.weights:
            # VGG expects channel-first images, so permute back.
            loss_vgg = self.vgg_loss(
                rgb.permute(0, 3, 1, 2),
                target_rgb.permute(0, 3, 1, 2),
                inputs["image_mask"],
            )
            loss_total += self.weights.vgg * loss_vgg
            loss_dict.update(loss_vgg=loss_vgg)

        if "prim_scale_var" in self.weights:
            log_prim_scale = th.log(prim_scale)
            # NOTE: should we detach this?
            log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
            loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
            loss_total += self.weights.prim_scale_var * loss_prim_scale_var
            loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


def process_losses(loss_dict, reduce=True, detach=True):
    """Preprocess the dict of losses outputs.

    Keeps only entries whose key starts with ``"loss_"`` and strips that
    prefix (only the leading occurrence). Optionally detaches tensors and
    reduces each value to a Python float via its mean. Plain numbers, such as
    a ``0.0`` placeholder, pass through detach unchanged and are converted
    with ``float()`` on reduce.

    Args:
        loss_dict: mapping of names to loss values (tensors or numbers).
        reduce: if True, convert every value to a Python ``float``.
        detach: if True, detach tensor values from the autograd graph.

    Returns:
        A new dict keyed by the stripped names.
    """
    prefix = "loss_"
    result = {k[len(prefix):]: v for k, v in loss_dict.items() if k.startswith(prefix)}
    if detach:
        result = {k: v.detach() if th.is_tensor(v) else v for k, v in result.items()}
    if reduce:
        result = {
            k: float(v.mean().item()) if hasattr(v, "mean") else float(v)
            for k, v in result.items()
        }
    return result