3DTopia-XL / dva /losses.py
FrozenBurning
single view to 3D init release
81ecb2b
raw
history blame
No virus
9.14 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch as th
import numpy as np
import logging
from .vgg import VGGLossMasked
logger = logging.getLogger("dva.{__name__}")
class DCTLoss(nn.Module):
def __init__(self, weights):
super().__init__()
self.weights = weights
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
target = inputs['gt']
recon = preds['recon']
posterior = preds['posterior']
fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
loss_recon_l1 = th.mean(th.abs(target - recon))
loss_kl = posterior.kl().mean()
loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl)
loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
class VAESepL2Loss(nn.Module):
def __init__(self, weights):
super().__init__()
self.weights = weights
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
target = inputs['gt']
recon = preds['recon']
posterior = preds['posterior']
recon_diff = (target - recon) ** 2
loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
loss_kl = posterior.kl().mean()
loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
if "kl" in self.weights:
loss_total += self.weights.kl * loss_kl
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
class VAESepLoss(nn.Module):
def __init__(self, weights):
super().__init__()
self.weights = weights
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
target = inputs['gt']
recon = preds['recon']
posterior = preds['posterior']
recon_diff = th.abs(target - recon)
loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
loss_kl = posterior.kl().mean()
loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
if "kl" in self.weights:
loss_total += self.weights.kl * loss_kl
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
class VAELoss(nn.Module):
def __init__(self, weights):
super().__init__()
self.weights = weights
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
target = inputs['gt']
recon = preds['recon']
posterior = preds['posterior']
loss_recon_l1 = th.mean(th.abs(target - recon))
loss_kl = posterior.kl().mean()
loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)
loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
class PrimSDFLoss(nn.Module):
def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
super().__init__()
self.weights = weights
self.shape_opt_steps = shape_opt_steps
self.tex_opt_steps = tex_opt_steps
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
if iteration < self.shape_opt_steps:
target_sdf = inputs['sdf']
sdf = preds['sdf']
loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
loss_dict.update(loss_sdf_l1=loss_sdf_l1)
loss_total = self.weights.sdf_l1 * loss_sdf_l1
prim_scale = preds["prim_scale"]
# we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube
if "vol_sum" in self.weights:
loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
loss_total += self.weights.vol_sum * loss_prim_vol_sum
if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps:
target_tex = inputs['tex']
tex = preds['tex']
loss_tex_l1 = th.mean(th.abs(tex - target_tex))
loss_dict.update(loss_tex_l1=loss_tex_l1)
loss_total = (
self.weights.rgb_l1 * loss_tex_l1
)
if "mat_l1" in self.weights:
target_mat = inputs['mat']
mat = preds['mat']
loss_mat_l1 = th.mean(th.abs(mat - target_mat))
loss_dict.update(loss_mat_l1=loss_mat_l1)
loss_total += self.weights.mat_l1 * loss_mat_l1
if "grad_l2" in self.weights:
loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
loss_total += self.weights.grad_l2 * loss_grad_l2
loss_dict.update(loss_grad_l2=loss_grad_l2)
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
class TotalMVPLoss(nn.Module):
def __init__(self, weights, assets=None):
super().__init__()
self.weights = weights
if "vgg" in self.weights:
self.vgg_loss = VGGLossMasked()
def forward(self, inputs, preds, iteration=None):
loss_dict = {"loss_total": 0.0}
B = inputs["image"].shape
# rgb
target_rgb = inputs["image"].permute(0, 2, 3, 1)
# removing the mask
target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]
rgb = preds["rgb"]
loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
loss_dict.update(loss_rgb_mse=loss_rgb_mse)
alpha = preds["alpha"]
# mask loss
target_mask = inputs["image_mask"][:, 0].to(th.float32)
loss_mask_mae = th.mean((target_mask - alpha).abs())
loss_dict.update(loss_mask_mae=loss_mask_mae)
B = alpha.shape[0]
# beta prior on opacity
loss_alpha_prior = th.mean(
th.log(0.1 + alpha.reshape(B, -1))
+ th.log(0.1 + 1.0 - alpha.reshape(B, -1))
- -2.20727
)
loss_dict.update(loss_alpha_prior=loss_alpha_prior)
prim_scale = preds["prim_scale"]
loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
loss_total = (
self.weights.rgb_mse * loss_rgb_mse
+ self.weights.mask_mae * loss_mask_mae
+ self.weights.alpha_prior * loss_alpha_prior
+ self.weights.prim_vol_sum * loss_prim_vol_sum
)
if "embs_l2" in self.weights:
loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
loss_total += self.weights.embs_l2 * loss_embs_l2
loss_dict.update(loss_embs_l2=loss_embs_l2)
if "vgg" in self.weights:
loss_vgg = self.vgg_loss(
rgb.permute(0, 3, 1, 2),
target_rgb.permute(0, 3, 1, 2),
inputs["image_mask"],
)
loss_total += self.weights.vgg * loss_vgg
loss_dict.update(loss_vgg=loss_vgg)
if "prim_scale_var" in self.weights:
log_prim_scale = th.log(prim_scale)
# NOTE: should we detach this?
log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
loss_total += self.weights.prim_scale_var * loss_prim_scale_var
loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)
loss_dict["loss_total"] = loss_total
return loss_total, loss_dict
def process_losses(loss_dict, reduce=True, detach=True):
"""Preprocess the dict of losses outputs."""
result = {
k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")
}
if detach:
result = {k: v.detach() for k, v in result.items()}
if reduce:
result = {k: float(v.mean().item()) for k, v in result.items()}
return result