# Copyright 2022 The Nerfstudio Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Collection of Losses. """ import torch import torch.nn.functional as F from torch import nn from torchtyping import TensorType from torch.autograd import Variable import numpy as np from math import exp # from nerfstudio.cameras.rays import RaySamples # from nerfstudio.field_components.field_heads import FieldHeadNames L1Loss = nn.L1Loss MSELoss = nn.MSELoss LOSSES = {"L1": L1Loss, "MSE": MSELoss} EPS = 1.0e-7 def outer( t0_starts: TensorType[..., "num_samples_0"], t0_ends: TensorType[..., "num_samples_0"], t1_starts: TensorType[..., "num_samples_1"], t1_ends: TensorType[..., "num_samples_1"], y1: TensorType[..., "num_samples_1"], ) -> TensorType[..., "num_samples_0"]: """Faster version of https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117 https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64 Args: t0_starts: start of the interval edges t0_ends: end of the interval edges t1_starts: start of the interval edges t1_ends: end of the interval edges y1: weights """ cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1) idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1 idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1) idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right") idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1) cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1) cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1) y0_outer = cy1_hi - cy1_lo return y0_outer def lossfun_outer( t: TensorType[..., "num_samples+1"], w: TensorType[..., "num_samples"], t_env: TensorType[..., "num_samples+1"], w_env: TensorType[..., "num_samples"], ): """ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136 https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80 Args: t: interval edges w: weights t_env: interval edges of the upper bound enveloping historgram w_env: weights that should upper bound the inner (t,w) histogram """ w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env) return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS) def ray_samples_to_sdist(ray_samples): """Convert ray samples to s space""" starts = ray_samples.spacing_starts ends = ray_samples.spacing_ends sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1) # (num_rays, num_samples + 1) return sdist def interlevel_loss(weights_list, ray_samples_list): """Calculates the proposal loss in the MipNeRF-360 paper. https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515 https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133 """ c = ray_samples_to_sdist(ray_samples_list[-1]).detach() w = weights_list[-1][..., 0].detach() loss_interlevel = 0.0 for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]): sdist = ray_samples_to_sdist(ray_samples) cp = sdist # (num_rays, num_samples + 1) wp = weights[..., 0] # (num_rays, num_samples) loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp)) return loss_interlevel ## zip-NeRF losses def blur_stepfun(x, y, r): x_c = torch.cat([x - r, x + r], dim=-1) x_r, x_idx = torch.sort(x_c, dim=-1) zeros = torch.zeros_like(y[:, :1]) y_1 = (torch.cat([y, zeros], dim=-1) - torch.cat([zeros, y], dim=-1)) / (2 * r) x_idx = x_idx[:, :-1] y_2 = torch.cat([y_1, -y_1], dim=-1)[ torch.arange(x_idx.shape[0]).reshape(-1, 1).expand(x_idx.shape).to(x_idx.device), x_idx ] y_r = torch.cumsum((x_r[:, 1:] - x_r[:, :-1]) * torch.cumsum(y_2, dim=-1), dim=-1) y_r = torch.cat([zeros, y_r], dim=-1) return x_r, y_r def interlevel_loss_zip(weights_list, ray_samples_list): """Calculates the proposal loss in the Zip-NeRF paper.""" c = ray_samples_to_sdist(ray_samples_list[-1]).detach() w = weights_list[-1][..., 0].detach() # 1. normalize w_normalize = w / (c[:, 1:] - c[:, :-1]) loss_interlevel = 0.0 for ray_samples, weights, r in zip(ray_samples_list[:-1], weights_list[:-1], [0.03, 0.003]): # 2. step blur with different r x_r, y_r = blur_stepfun(c, w_normalize, r) y_r = torch.clip(y_r, min=0) assert (y_r >= 0.0).all() # 3. accumulate y_cum = torch.cumsum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]), dim=-1) y_cum = torch.cat([torch.zeros_like(y_cum[:, :1]), y_cum], dim=-1) # 4 loss sdist = ray_samples_to_sdist(ray_samples) cp = sdist # (num_rays, num_samples + 1) wp = weights[..., 0] # (num_rays, num_samples) # resample inds = torch.searchsorted(x_r, cp, side="right") below = torch.clamp(inds - 1, 0, x_r.shape[-1] - 1) above = torch.clamp(inds, 0, x_r.shape[-1] - 1) cdf_g0 = torch.gather(x_r, -1, below) bins_g0 = torch.gather(y_cum, -1, below) cdf_g1 = torch.gather(x_r, -1, above) bins_g1 = torch.gather(y_cum, -1, above) t = torch.clip(torch.nan_to_num((cp - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1) bins = bins_g0 + t * (bins_g1 - bins_g0) w_gt = bins[:, 1:] - bins[:, :-1] # TODO here might be unstable when wp is very small loss_interlevel += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5)) return loss_interlevel # Verified def lossfun_distortion(t, w): """ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142 https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266 """ ut = (t[..., 1:] + t[..., :-1]) / 2 dut = torch.abs(ut[..., :, None] - ut[..., None, :]) loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1) loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3 return loss_inter + loss_intra def distortion_loss(weights_list, ray_samples_list): """From mipnerf360""" c = ray_samples_to_sdist(ray_samples_list[-1]) w = weights_list[-1][..., 0] loss = torch.mean(lossfun_distortion(c, w)) return loss # def nerfstudio_distortion_loss( # ray_samples: RaySamples, # densities: TensorType["bs":..., "num_samples", 1] = None, # weights: TensorType["bs":..., "num_samples", 1] = None, # ) -> TensorType["bs":..., 1]: # """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss. # .. math:: # \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty} # \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v} # where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)` # is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`. # Args: # ray_samples: Ray samples to compute loss over # densities: Predicted sample densities # weights: Predicted weights from densities and sample locations # """ # if torch.is_tensor(densities): # assert not torch.is_tensor(weights), "Cannot use both densities and weights" # # Compute the weight at each sample location # weights = ray_samples.get_weights(densities) # if torch.is_tensor(weights): # assert not torch.is_tensor(densities), "Cannot use both densities and weights" # starts = ray_samples.spacing_starts # ends = ray_samples.spacing_ends # assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends" # midpoints = (starts + ends) / 2.0 # (..., num_samples, 1) # loss = ( # weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0]) # ) # (..., num_samples, num_samples) # loss = torch.sum(loss, dim=(-1, -2))[..., None] # (..., num_samples) # loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2) # return loss def orientation_loss( weights: TensorType["bs":..., "num_samples", 1], normals: TensorType["bs":..., "num_samples", 3], viewdirs: TensorType["bs":..., 3], ): """Orientation loss proposed in Ref-NeRF. Loss that encourages that all visible normals are facing towards the camera. """ w = weights n = normals v = viewdirs n_dot_v = (n * v[..., None, :]).sum(axis=-1) return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1) def pred_normal_loss( weights: TensorType["bs":..., "num_samples", 1], normals: TensorType["bs":..., "num_samples", 3], pred_normals: TensorType["bs":..., "num_samples", 3], ): """Loss between normals calculated from density and normals from prediction network.""" return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1) def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor): """normal consistency loss as monosdf Args: normal_pred (torch.Tensor): volume rendered normal normal_gt (torch.Tensor): monocular normal """ normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1) normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1) l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean() cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean() return l1 + cos # copy from MiDaS def compute_scale_and_shift(prediction, target, mask): # system matrix: A = [[a_00, a_01], [a_10, a_11]] a_00 = torch.sum(mask * prediction * prediction, (1, 2)) a_01 = torch.sum(mask * prediction, (1, 2)) a_11 = torch.sum(mask, (1, 2)) # right hand side: b = [b_0, b_1] b_0 = torch.sum(mask * prediction * target, (1, 2)) b_1 = torch.sum(mask * target, (1, 2)) # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b x_0 = torch.zeros_like(b_0) x_1 = torch.zeros_like(b_1) det = a_00 * a_11 - a_01 * a_01 valid = det.nonzero() x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] return x_0, x_1 def reduction_batch_based(image_loss, M): # average of all valid pixels of the batch # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) divisor = torch.sum(M) if divisor == 0: return 0 else: return torch.sum(image_loss) / divisor def reduction_image_based(image_loss, M): # mean of average of valid pixels of an image # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) valid = M.nonzero() image_loss[valid] = image_loss[valid] / M[valid] return torch.mean(image_loss) def mse_loss(prediction, target, mask, reduction=reduction_batch_based): M = torch.sum(mask, (1, 2)) res = prediction - target image_loss = torch.sum(mask * res * res, (1, 2)) return reduction(image_loss, 2 * M) def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): M = torch.sum(mask, (1, 2)) diff = prediction - target diff = torch.mul(mask, diff) grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) grad_x = torch.mul(mask_x, grad_x) grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) grad_y = torch.mul(mask_y, grad_y) image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) return reduction(image_loss, M) class MiDaSMSELoss(nn.Module): def __init__(self, reduction="batch-based"): super().__init__() if reduction == "batch-based": self.__reduction = reduction_batch_based else: self.__reduction = reduction_image_based def forward(self, prediction, target, mask): return mse_loss(prediction, target, mask, reduction=self.__reduction) class GradientLoss(nn.Module): def __init__(self, scales=4, reduction="batch-based"): super().__init__() if reduction == "batch-based": self.__reduction = reduction_batch_based else: self.__reduction = reduction_image_based self.__scales = scales def forward(self, prediction, target, mask): total = 0 for scale in range(self.__scales): step = pow(2, scale) total += gradient_loss( prediction[:, ::step, ::step], target[:, ::step, ::step], mask[:, ::step, ::step], reduction=self.__reduction, ) return total class ScaleAndShiftInvariantLoss(nn.Module): def __init__(self, alpha=0.5, scales=4, reduction="batch-based"): super().__init__() self.__data_loss = MiDaSMSELoss(reduction=reduction) self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) self.__alpha = alpha self.__prediction_ssi = None def forward(self, prediction, target, mask): scale, shift = compute_scale_and_shift(prediction, target, mask) self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) total = self.__data_loss(self.__prediction_ssi, target, mask) if self.__alpha > 0: total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) return total def __get_prediction_ssi(self): return self.__prediction_ssi prediction_ssi = property(__get_prediction_ssi) # end copy # copy from https://github.com/svip-lab/Indoor-SfMLearner/blob/0d682b7ce292484e5e3e2161fc9fc07e2f5ca8d1/layers.py#L218 class SSIM(nn.Module): """Layer to compute the SSIM loss between a pair of images""" def __init__(self, patch_size): super(SSIM, self).__init__() self.mu_x_pool = nn.AvgPool2d(patch_size, 1) self.mu_y_pool = nn.AvgPool2d(patch_size, 1) self.sig_x_pool = nn.AvgPool2d(patch_size, 1) self.sig_y_pool = nn.AvgPool2d(patch_size, 1) self.sig_xy_pool = nn.AvgPool2d(patch_size, 1) self.refl = nn.ReflectionPad2d(patch_size // 2) self.C1 = 0.01**2 self.C2 = 0.03**2 def forward(self, x, y): x = self.refl(x) y = self.refl(y) mu_x = self.mu_x_pool(x) mu_y = self.mu_y_pool(y) sigma_x = self.sig_x_pool(x**2) - mu_x**2 sigma_y = self.sig_y_pool(y**2) - mu_y**2 sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2) return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) # TODO test different losses class NCC(nn.Module): """Layer to compute the normalization cross correlation (NCC) of patches""" def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01): super(NCC, self).__init__() self.patch_size = patch_size self.min_patch_variance = min_patch_variance def forward(self, x, y): # TODO if we use gray image we should do it right after loading the image to save computations # to gray image x = torch.mean(x, dim=1) y = torch.mean(y, dim=1) x_mean = torch.mean(x, dim=(1, 2), keepdim=True) y_mean = torch.mean(y, dim=(1, 2), keepdim=True) x_normalized = x - x_mean y_normalized = y - y_mean norm = torch.sum(x_normalized * y_normalized, dim=(1, 2)) var = torch.square(x_normalized).sum(dim=(1, 2)) * torch.square(y_normalized).sum(dim=(1, 2)) denom = torch.sqrt(var + 1e-6) ncc = norm / (denom + 1e-6) # ignore pathces with low variances not_valid = (torch.square(x_normalized).sum(dim=(1, 2)) < self.min_patch_variance) | ( torch.square(y_normalized).sum(dim=(1, 2)) < self.min_patch_variance ) ncc[not_valid] = 1.0 score = 1 - ncc.clip(-1.0, 1.0) # 0->2: smaller, better return score[:, None, None, None] class MultiViewLoss(nn.Module): """compute multi-view consistency loss""" def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01): super(MultiViewLoss, self).__init__() self.patch_size = patch_size self.topk = topk self.min_patch_variance = min_patch_variance # TODO make metric configurable # self.ssim = SSIM(patch_size=patch_size) # self.ncc = NCC(patch_size=patch_size) self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance) self.iter = 0 def forward(self, patches: torch.Tensor, valid: torch.Tensor): """take the mim Args: patches (torch.Tensor): _description_ valid (torch.Tensor): _description_ Returns: _type_: _description_ """ num_imgs, num_rays, _, num_channels = patches.shape if num_rays <= 0: return torch.tensor(0.0).to(patches.device) ref_patches = ( patches[:1, ...] .reshape(1, num_rays, self.patch_size, self.patch_size, num_channels) .expand(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels) .reshape(-1, self.patch_size, self.patch_size, num_channels) .permute(0, 3, 1, 2) ) # [N_src*N_rays, 3, patch_size, patch_size] src_patches = ( patches[1:, ...] .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels) .reshape(-1, self.patch_size, self.patch_size, num_channels) .permute(0, 3, 1, 2) ) # [N_src*N_rays, 3, patch_size, patch_size] # apply same reshape to the valid mask src_patches_valid = ( valid[1:, ...] .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, 1) .reshape(-1, self.patch_size, self.patch_size, 1) .permute(0, 3, 1, 2) ) # [N_src*N_rays, 1, patch_size, patch_size] ssim = self.ssim(ref_patches.detach(), src_patches) ssim = torch.mean(ssim, dim=(1, 2, 3)) ssim = ssim.reshape(num_imgs - 1, num_rays) # ignore invalid patch by setting ssim error to very large value ssim_valid = ( src_patches_valid.reshape(-1, self.patch_size * self.patch_size).all(dim=-1).reshape(num_imgs - 1, num_rays) ) # we should mask the error after we select the topk value, otherwise we might select far way patches that happens to be inside the image # ssim[torch.logical_not(ssim_valid)] = 1.1 # max ssim_error is 1 min_ssim, idx = torch.topk(ssim, k=self.topk, largest=False, dim=0, sorted=True) min_ssim_valid = ssim_valid[idx, torch.arange(num_rays)[None].expand_as(idx)] # TODO how to set this value for better visualization min_ssim[torch.logical_not(min_ssim_valid)] = 0.0 # max ssim_error is 1 if False: # visualization of topK error computations import cv2 import numpy as np vis_patch_num = num_rays K = min(100, vis_patch_num) image = ( patches[:, :vis_patch_num, :, :] .reshape(-1, vis_patch_num, self.patch_size, self.patch_size, 3) .permute(1, 2, 0, 3, 4) .reshape(vis_patch_num * self.patch_size, -1, 3) ) src_patches_reshaped = src_patches.reshape( num_imgs - 1, num_rays, 3, self.patch_size, self.patch_size ).permute(1, 0, 3, 4, 2) idx = idx.permute(1, 0) selected_patch = ( src_patches_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx] .permute(0, 2, 1, 3, 4) .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 3)[:vis_patch_num] .reshape(-1, self.topk * self.patch_size, 3) ) # apply same reshape to the valid mask src_patches_valid_reshaped = src_patches_valid.reshape( num_imgs - 1, num_rays, 1, self.patch_size, self.patch_size ).permute(1, 0, 3, 4, 2) selected_patch_valid = ( src_patches_valid_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx] .permute(0, 2, 1, 3, 4) .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 1)[:vis_patch_num] .reshape(-1, self.topk * self.patch_size, 1) ) # valid to image selected_patch_valid = selected_patch_valid.expand_as(selected_patch).float() # breakpoint() image = torch.cat([selected_patch_valid, selected_patch, image], dim=1) # select top rays with highest errors image = image.reshape(num_rays, self.patch_size, -1, 3) _, idx2 = torch.topk( torch.sum(min_ssim, dim=0) / (min_ssim_valid.float().sum(dim=0) + 1e-6), k=K, largest=True, dim=0, sorted=True, ) image = image[idx2].reshape(K * self.patch_size, -1, 3) cv2.imwrite(f"vis/{self.iter}.png", (image.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1]) self.iter += 1 if self.iter == 9: breakpoint() return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6) # sensor depth loss, adapted from https://github.com/dazinovic/neural-rgbd-surface-reconstruction/blob/main/losses.py # class SensorDepthLoss(nn.Module): # """Sensor Depth loss""" # def __init__(self, truncation: float): # super(SensorDepthLoss, self).__init__() # self.truncation = truncation # 0.05 * 0.3 5cm scaled # def forward(self, batch, outputs): # """take the mim # Args: # batch (Dict): inputs # outputs (Dict): outputs data from surface model # Returns: # l1_loss: l1 loss # freespace_loss: free space loss # sdf_loss: sdf loss # """ # depth_pred = outputs["depth"] # depth_gt = batch["sensor_depth"].to(depth_pred.device)[..., None] # valid_gt_mask = depth_gt > 0.0 # l1_loss = torch.sum(valid_gt_mask * torch.abs(depth_gt - depth_pred)) / (valid_gt_mask.sum() + 1e-6) # # free space loss and sdf loss # ray_samples = outputs["ray_samples"] # filed_outputs = outputs["field_outputs"] # pred_sdf = filed_outputs[FieldHeadNames.SDF][..., 0] # directions_norm = outputs["directions_norm"] # z_vals = ray_samples.frustums.starts[..., 0] / directions_norm # truncation = self.truncation # front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation)) # back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation)) # sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask) # num_fs_samples = front_mask.sum() # num_sdf_samples = sdf_mask.sum() # num_samples = num_fs_samples + num_sdf_samples + 1e-6 # fs_weight = 1.0 - num_fs_samples / num_samples # sdf_weight = 1.0 - num_sdf_samples / num_samples # free_space_loss = torch.mean((F.relu(truncation - pred_sdf) * front_mask) ** 2) * fs_weight # sdf_loss = torch.mean(((z_vals + pred_sdf) - depth_gt) ** 2 * sdf_mask) * sdf_weight # return l1_loss, free_space_loss, sdf_loss r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm. It is proposed in the ICCV2023 paper `S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`. Arguments: s3im_kernel_size (int): kernel size in ssim's convolution(default: 4) s3im_stride (int): stride in ssim's convolution(default: 4) s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10) s3im_patch_height (height): height of virtual patch(default: 64) """ class S3IM(torch.nn.Module): def __init__(self, s3im_kernel_size = 4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average = True): super(S3IM, self).__init__() self.s3im_kernel_size = s3im_kernel_size self.s3im_stride = s3im_stride self.s3im_repeat_time = s3im_repeat_time self.s3im_patch_height = s3im_patch_height self.size_average = size_average self.channel = 1 self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel) def gaussian(self, s3im_kernel_size, sigma): gauss = torch.Tensor([exp(-(x - s3im_kernel_size//2)**2/float(2*sigma**2)) for x in range(s3im_kernel_size)]) return gauss/gauss.sum() def create_kernel(self, s3im_kernel_size, channel): _1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) s3im_kernel = Variable(_2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous()) return s3im_kernel def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average = True, s3im_stride=None): mu1 = F.conv2d(img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) mu2 = F.conv2d(img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1*mu2 sigma1_sq = F.conv2d(img1*img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_sq sigma2_sq = F.conv2d(img2*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu2_sq sigma12 = F.conv2d(img1*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) def ssim_loss(self, img1, img2): """ img1, img2: torch.Tensor([b,c,h,w]) """ (_, channel, _, _) = img1.size() if channel == self.channel and self.s3im_kernel.data.type() == img1.data.type(): s3im_kernel = self.s3im_kernel else: s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel) if img1.is_cuda: s3im_kernel = s3im_kernel.cuda(img1.get_device()) s3im_kernel = s3im_kernel.type_as(img1) self.s3im_kernel = s3im_kernel self.channel = channel return self._ssim(img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride) def forward(self, src_vec, tar_vec): loss = 0.0 index_list = [] for i in range(self.s3im_repeat_time): if i == 0: tmp_index = torch.arange(len(tar_vec)) index_list.append(tmp_index) else: ran_idx = torch.randperm(len(tar_vec)) index_list.append(ran_idx) res_index = torch.cat(index_list) tar_all = tar_vec[res_index] src_all = src_vec[res_index] tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) src_patch = src_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) loss = (1 - self.ssim_loss(src_patch, tar_patch)) return loss