# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Volumetric autoencoder (image -> encoding -> volume -> image) """ import inspect import time from typing import Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import models.utils from extensions.utils.utils import compute_raydirs @torch.jit.script def compute_raydirs_ref(pixelcoords : torch.Tensor, viewrot : torch.Tensor, focal : torch.Tensor, princpt : torch.Tensor): raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :] raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1) raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2) raydir = F.normalize(raydir, dim=-1) return raydir @torch.jit.script def compute_rmbounds(viewpos : torch.Tensor, raydir : torch.Tensor, volradius : float): viewpos = viewpos / volradius # compute raymarching starting points with torch.no_grad(): t1 = (-1. - viewpos[:, None, None, :]) / raydir t2 = ( 1. - viewpos[:, None, None, :]) / raydir tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]), torch.max(torch.min(t1[..., 1], t2[..., 1]), torch.min(t1[..., 2], t2[..., 2]))) tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]), torch.min(torch.max(t1[..., 1], t2[..., 1]), torch.max(t1[..., 2], t2[..., 2]))) intersections = tmin < tmax t = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.) tmin = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.) tmax = torch.where(intersections, tmax, torch.zeros_like(tmin)) raypos = viewpos[:, None, None, :] + raydir * 0. tminmax = torch.stack([tmin, tmax], dim=-1) return raypos, tminmax class Autoencoder(nn.Module): def __init__(self, dataset, encoder, decoder, raymarcher, colorcal, volradius, bgmodel=None, encoderinputs=[], topology=None, imagemean=0., imagestd=1., vertmask=None, cudaraydirs=True): super(Autoencoder, self).__init__() self.encoder = encoder self.decoder = decoder self.raymarcher = raymarcher self.colorcal = colorcal self.volradius = volradius self.bgmodel = bgmodel self.encoderinputs = encoderinputs if hasattr(dataset, 'vertmean'): self.register_buffer("vertmean", torch.from_numpy(dataset.vertmean), persistent=False) self.vertstd = dataset.vertstd if hasattr(dataset, 'texmean'): self.register_buffer("texmean", torch.from_numpy(dataset.texmean), persistent=False) self.texstd = dataset.texstd self.imagemean = imagemean self.imagestd = imagestd self.cudaraydirs = cudaraydirs if vertmask is not None: self.register_buffer("vertmask", torch.from_numpy(vertmask), persistent=False) self.irgbmsestart = -1 def forward(self, camrot : torch.Tensor, campos : torch.Tensor, focal : torch.Tensor, princpt : torch.Tensor, camindex : Optional[torch.Tensor] = None, pixelcoords : Optional[torch.Tensor]=None, modelmatrix : Optional[torch.Tensor]=None, modelmatrixinv : Optional[torch.Tensor]=None, modelmatrix_next : Optional[torch.Tensor]=None, modelmatrixinv_next : Optional[torch.Tensor]=None, validinput : Optional[torch.Tensor]=None, avgtex : Optional[torch.Tensor]=None, avgtex_next : Optional[torch.Tensor]=None, verts : Optional[torch.Tensor]=None, verts_next : Optional[torch.Tensor]=None, fixedcamimage : Optional[torch.Tensor]=None, encoding : Optional[torch.Tensor]=None, image : Optional[torch.Tensor]=None, imagemask : Optional[torch.Tensor]=None, imagevalid : Optional[torch.Tensor]=None, bg : Optional[torch.Tensor]=None, renderoptions : dict ={}, trainiter : int=-1, evaliter : Optional[torch.Tensor]=None, outputlist : list=[], losslist : list=[], **kwargs): """ Parameters ---------- camrot : torch.Tensor [B, 3, 3] Rotation matrix of target view camera campos : torch.Tensor [B, 3] Position of target view camera focal : torch.Tensor [B, 2] Focal length of target view camera princpt : torch.Tensor [B, 2] Princple point of target view camera camindex : torch.Tensor[int32], optional [B] Camera index within the list of all cameras pixelcoords : torch.Tensor, optional [B, H', W', 2] Pixel coordinates to render of the target view camera modelmatrix : torch.Tensor, optional [B, 3, 3] Relative transform from the 'neutral' pose of object validinput : torch.Tensor, optional [B] Whether the current batch element is valid (used for missing images) avgtex : torch.Tensor, optional [B, 3, 1024, 1024] Texture map averaged from all viewpoints verts : torch.Tensor, optional [B, 7306, 3] Mesh vertex positions fixedcamimage : torch.Tensor, optional [B, 3, 512, 334] Camera images from a one or more cameras that are always the same (i.e., unrelated to target) encoding : torch.Tensor, optional [B, 256] Direct encodings (overrides encoder) image : torch.Tensor, optional [B, 3, H, W] Target image imagemask : torch.Tensor, optional [B, 1, H, W] Target image mask for computing reconstruction loss imagevalid : torch.Tensor, optional [B] bg : torch.Tensor, optional [B, 3, H, W] renderoptions : dict Rendering/raymarching options (e.g., stepsize, whether to output debug images, etc.) trainiter : int Training iteration number outputlist : list Values to return (e.g., image reconstruction, debug output) losslist : list Losses to output (e.g., image reconstruction loss, priors) Returns ------- result : dict Contains outputs specified in outputlist (e.g., image rgb reconstruction "irgbrec") losses : dict Losses to optimize """ resultout = {} resultlosses = {} aestart = time.time() # encode/get encoding # verts [6, 7306, 3] # avgtex [6, 3, 256, 256] if encoding is None: if "enctime" in outputlist: torch.cuda.synchronize() encstart = time.time() encout, enclosses = self.encoder( *[dict(verts=verts, avgtex=avgtex, fixedcamimage=fixedcamimage)[k] for k in self.encoderinputs], losslist=losslist) if "enctime" in outputlist: torch.cuda.synchronize() encend = time.time() resultout["enctime"] = encend - encstart # encoding [6, 256] encoding = encout["encoding"] resultlosses.update(enclosses) # compute relative viewing position if modelmatrixinv is not None: viewrot = torch.bmm(camrot, modelmatrixinv[:, :3, :3]) viewpos = torch.bmm((campos[:, :] - modelmatrixinv[:, :3, 3])[:, None, :], modelmatrixinv[:, :3, :3])[:, 0, :] else: viewrot = camrot viewpos = campos # decode volumetric representation if "dectime" in outputlist: torch.cuda.synchronize() decstart = time.time() if isinstance(self.decoder, torch.jit.ScriptModule): # torchscript requires statically typed dict renderoptionstyped : Dict[str, str] = {k: str(v) for k, v in renderoptions.items()} else: renderoptionstyped = renderoptions decout, declosses = self.decoder( encoding, viewpos, renderoptions=renderoptionstyped, trainiter=trainiter, evaliter=evaliter, losslist=losslist) if "dectime" in outputlist: torch.cuda.synchronize() decend = time.time() resultout["dectime"] = decend - decstart resultlosses.update(declosses) # compute vertex loss if "vertmse" in losslist: weight = validinput[:, None, None].expand_as(verts) if hasattr(self, "vertmask"): weight = weight * self.vertmask[None, :, None] vertsrecstd = (decout["verts"] - self.vertmean) / self.vertstd vertsqerr = weight * (verts - vertsrecstd) ** 2 vertmse = torch.sum(vertsqerr.view(vertsqerr.size(0), -1), dim=-1) vertmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) resultlosses["vertmse"] = (vertmse, vertmse_weight) # compute texture loss if "trgbmse" in losslist or "trgbsqerr" in outputlist: weight = (validinput[:, None, None, None] * texmask[:, None, :, :].float()).expand_as(tex).contiguous() # re-standardize texrecstd = (decout["tex"] - self.texmean.to("cuda")) / self.texstd texstd = (tex - self.texmean.to("cuda")) / self.texstd texsqerr = weight * (texstd - texrecstd) ** 2 if "trgbsqerr" in outputlist: resultout["trgbsqerr"] = texsqerr # texture rgb mean-squared-error if "trgbmse" in losslist: texmse = torch.sum(texsqerr.view(texsqerr.size(0), -1), dim=-1) texmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) resultlosses["trgbmse"] = (texmse, texmse_weight) # subsample depth, imagerec, imagerecmask if image is not None and pixelcoords.size()[1:3] != image.size()[2:4]: imagesize = torch.tensor(image.size()[3:1:-1], dtype=torch.float32, device=pixelcoords.device) else: imagesize = torch.tensor(pixelcoords.size()[2:0:-1], dtype=torch.float32, device=pixelcoords.device) samplecoords = pixelcoords * 2. / (imagesize[None, None, None, :] - 1.) - 1. # compute ray directions if self.cudaraydirs: raypos, raydir, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, self.volradius) else: raydir = compute_raydirs_ref(pixelcoords, viewrot, focal, princpt) raypos, tminmax = compute_rmbounds(viewpos, raydir, self.volradius) if "dtstd" in renderoptions: renderoptions["dt"] = renderoptions["dt"] * \ torch.exp(torch.randn(1) * renderoptions.get("dtstd")).item() if renderoptions.get("unbiastminmax", False): stepsize = renderoptions["dt"] / self.volradius tminmax = torch.floor(tminmax / stepsize) * stepsize if renderoptions.get("tminmaxblocks", False): bx, by = renderoptions.get("blocksize", (8, 16)) H, W = tminmax.size(1), tminmax.size(2) tminmax = tminmax.view(tminmax.size(0), H // by, by, W // bx, bx, 2) tminmax = tminmax.amin(dim=[2, 4], keepdim=True) tminmax = tminmax.repeat(1, 1, by, 1, bx, 1) tminmax = tminmax.view(tminmax.size(0), H, W, 2) # raymarch if "rmtime" in outputlist: torch.cuda.synchronize() rmstart = time.time() # rayrgba [6, 4, 384, 384] rayrgba, rmlosses = self.raymarcher(raypos, raydir, tminmax, decout=decout, renderoptions=renderoptions, trainiter=trainiter, evaliter=evaliter, losslist=losslist) resultlosses.update(rmlosses) if "rmtime" in outputlist: torch.cuda.synchronize() rmend = time.time() resultout["rmtime"] = rmend - rmstart if isinstance(rayrgba, tuple): rayrgb, rayalpha = rayrgba else: rayrgb, rayalpha = rayrgba[:, :3, :, :].contiguous(), rayrgba[:, 3:4, :, :].contiguous() # beta distribution prior on final opacity if "alphapr" in losslist: alphaprior = torch.mean( torch.log(0.1 + rayalpha.view(rayalpha.size(0), -1)) + torch.log(0.1 + 1. - rayalpha.view(rayalpha.size(0), -1)) - -2.20727, dim=-1) resultlosses["alphapr"] = alphaprior # color correction if camindex is not None and not renderoptions.get("nocolcorrect", False): rayrgb = self.colorcal(rayrgb, camindex) # background decoder if self.bgmodel is not None and not renderoptions.get("nobg", False): if "bgtime" in outputlist: torch.cuda.synchronize() bgstart = time.time() raypos, raydir, tminmax = compute_raydirs(campos, camrot, focal, princpt, pixelcoords, self.volradius) rayposbeg = raypos + raydir * tminmax[..., 0:1] rayposend = raypos + raydir * tminmax[..., 1:2] bg = self.bgmodel(bg, camindex, campos, rayposend, raydir, samplecoords, trainiter=trainiter) # alpha matting if bg is not None: rayrgb = rayrgb + (1. - rayalpha) * bg if "bg" in outputlist: resultout["bg"] = bg if "bgtime" in outputlist: torch.cuda.synchronize() bgend = time.time() resultout["bgtime"] = bgend - bgstart if "irgbrec" in outputlist: resultout["irgbrec"] = rayrgb if "irgbarec" in outputlist: resultout["irgbarec"] = torch.cat([rayrgb, rayalpha], dim=1) if "irgbflip" in outputlist: resultout["irgbflip"] = torch.cat([rayrgb[i:i+1] if i % 4 < 2 else image[i:i+1] for i in range(image.size(0))], dim=0) # image rgb loss if image is not None and trainiter > self.irgbmsestart: # subsample image if pixelcoords.size()[1:3] != image.size()[2:4]: image = F.grid_sample(image, samplecoords, align_corners=True) if imagemask is not None: imagemask = F.grid_sample(imagemask, samplecoords, align_corners=True) # compute reconstruction loss weighting weight = torch.ones_like(image) * validinput[:, None, None, None] if imagevalid is not None: weight = weight * imagevalid[:, None, None, None] if imagemask is not None: weight = weight * imagemask if "irgbsqerr" in outputlist: irgbsqerr_nonorm = (weight * (image - rayrgb) ** 2).contiguous() resultout["irgbsqerr"] = torch.sqrt(irgbsqerr_nonorm.mean(dim=1, keepdim=True)) # standardize rayrgb = (rayrgb - self.imagemean) / self.imagestd image = (image - self.imagemean) / self.imagestd irgbsqerr = (weight * (image - rayrgb) ** 2).contiguous() if "irgbmse" in losslist: irgbmse = torch.sum(irgbsqerr.view(irgbsqerr.size(0), -1), dim=-1) irgbmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) resultlosses["irgbmse"] = (irgbmse, irgbmse_weight) aeend = time.time() if "aetime" in outputlist: resultout["aetime"] = aeend - aestart return resultout, resultlosses