# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MVP decoder"""
import math
from typing import Optional, Dict, List

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import models.utils
from models.utils import LinearELR, ConvTranspose2dELR, ConvTranspose3dELR


@torch.jit.script
def compute_postex(geo, idxim, barim, volradius : float):
    # compute 3d coordinates of each texel in uv map
    return (
        barim[None, :, :, 0, None] * geo[:, idxim[:, :, 0], :] +
        barim[None, :, :, 1, None] * geo[:, idxim[:, :, 1], :] +
        barim[None, :, :, 2, None] * geo[:, idxim[:, :, 2], :]
        ).permute(0, 3, 1, 2) / volradius


@torch.jit.script
def compute_tbn(v0, v1, v2, vt0, vt1, vt2):
    v01 = v1 - v0
    v02 = v2 - v0
    vt01 = vt1 - vt0
    vt02 = vt2 - vt0
    f = 1. / (vt01[None, :, :, 0] * vt02[None, :, :, 1] -
              vt01[None, :, :, 1] * vt02[None, :, :, 0])
    tangent = f[:, :, :, None] * torch.stack([
        v01[:, :, :, 0] * vt02[None, :, :, 1] - v02[:, :, :, 0] * vt01[None, :, :, 1],
        v01[:, :, :, 1] * vt02[None, :, :, 1] - v02[:, :, :, 1] * vt01[None, :, :, 1],
        v01[:, :, :, 2] * vt02[None, :, :, 1] - v02[:, :, :, 2] * vt01[None, :, :, 1]], dim=-1)
    tangent = F.normalize(tangent, dim=-1)
    normal = torch.cross(v01, v02, dim=3)
    normal = F.normalize(normal, dim=-1)
    bitangent = torch.cross(tangent, normal, dim=3)
    bitangent = F.normalize(bitangent, dim=-1)

    # create matrix
    primrotmesh = torch.stack((tangent, bitangent, normal), dim=-1)

    return primrotmesh


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


# RGBA decoder
class SlabContentDecoder(nn.Module):
    def __init__(self, nprims, primsize, inch, outch, chstart=256, hstart=4,
            texwarp=False, elr=True, norm=None, mod=False, ub=True,
            upconv=None, penultch=None, use3dconv=False, reduced3dch=False):
        super(SlabContentDecoder, self).__init__()

        assert not texwarp
        assert upconv is None

        self.nprims = nprims
        self.primsize = primsize

        self.nprimy = int(math.sqrt(nprims))
        self.nprimx = nprims // self.nprimy
        assert nprims == self.nprimx * self.nprimy

        self.slabw = self.nprimx * primsize[0]
        self.slabh = self.nprimy * primsize[1]
        self.slabd = primsize[2]

        nlayers = int(math.log2(min(self.slabw, self.slabh))) - int(math.log2(hstart))
        nlayers3d = int(math.log2(self.slabd))
        nlayers2d = nlayers - nlayers3d

        lastch = chstart
        dims = (1, hstart, hstart * self.nprimx // self.nprimy)

        layers = []
        layers.append(LinearELR(inch, chstart*dims[1]*dims[2], act=nn.LeakyReLU(0.2)))
        layers.append(Reshape(-1, chstart, dims[1], dims[2]))

        for i in range(nlayers):
            nextch = lastch if i % 2 == 0 else lastch // 2

            if use3dconv and reduced3dch and i >= nlayers2d:
                nextch //= 2
            if i == nlayers - 2 and penultch is not None:
                nextch = penultch

            if use3dconv and i >= nlayers2d:
                if i == nlayers2d:
                    layers.append(Reshape(-1, lastch, 1, dims[1], dims[2]))
                layers.append(ConvTranspose3dELR(
                    lastch,
                    (outch if i == nlayers - 1 else nextch),
                    4, 2, 1,
                    ub=(dims[0]*2, dims[1]*2, dims[2]*2) if ub else None,
                    norm=None if i == nlayers - 1 else norm,
                    act=None if i == nlayers - 1 else nn.LeakyReLU(0.2)
                    ))
            else:
                layers.append(ConvTranspose2dELR(
                    lastch,
                    (outch * primsize[2] if i == nlayers - 1 else nextch),
                    4, 2, 1,
                    ub=(dims[1]*2, dims[2]*2) if ub else None,
                    norm=None if i == nlayers - 1 else norm,
                    act=None if i == nlayers - 1 else nn.LeakyReLU(0.2)
                    ))

            lastch = nextch
            dims = (dims[0] * (2 if use3dconv and i >= nlayers2d else 1),
                    dims[1] * 2, dims[2] * 2)
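
        # Worked example of the layer arithmetic above (illustrative comment,
        # not in the original source): with nprims=256, primsize=(32, 32, 32),
        # and hstart=4 the slab is 512x512 texels and 32 deep, so
        # nlayers = log2(512) - log2(4) = 7 upsampling steps; with use3dconv,
        # nlayers3d = log2(32) = 5 of them are 3D transposed convolutions and
        # nlayers2d = 2 are 2D. The 2D-only path instead folds the depth into
        # the channel dimension of its final layer (outch * primsize[2]).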

        self.mod = nn.Sequential(*layers)

    def forward(self, enc, renderoptions : Dict[str, str], trainiter : Optional[int]=None):
        x = self.mod(enc)

        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")

        if chlast is not None and bool(chlast):
            # reorder channels last
            if len(x.size()) == 5:
                outch = x.size(1)
                x = x.view(x.size(0), outch, self.primsize[2], self.nprimy,
                        self.primsize[1], self.nprimx, self.primsize[0])
                x = x.permute(0, 3, 5, 2, 4, 6, 1)
                x = x.reshape(x.size(0), self.nprims, self.primsize[2],
                        self.primsize[1], self.primsize[0], outch)
            else:
                outch = x.size(1) // self.primsize[2]
                x = x.view(x.size(0), self.primsize[2], outch, self.nprimy,
                        self.primsize[1], self.nprimx, self.primsize[0])
                x = x.permute(0, 3, 5, 1, 4, 6, 2)
                x = x.reshape(x.size(0), self.nprims, self.primsize[2],
                        self.primsize[1], self.primsize[0], outch)
        else:
            if len(x.size()) == 5:
                outch = x.size(1)
                x = x.view(x.size(0), outch, self.primsize[2], self.nprimy,
                        self.primsize[1], self.nprimx, self.primsize[0])
                x = x.permute(0, 3, 5, 1, 2, 4, 6)
                x = x.reshape(x.size(0), self.nprims, outch, self.primsize[2],
                        self.primsize[1], self.primsize[0])
            else:
                outch = x.size(1) // self.primsize[2]
                x = x.view(x.size(0), self.primsize[2], outch, self.nprimy,
                        self.primsize[1], self.nprimx, self.primsize[0])
                x = x.permute(0, 3, 5, 2, 1, 4, 6)
                x = x.reshape(x.size(0), self.nprims, outch, self.primsize[2],
                        self.primsize[1], self.primsize[0])

        return x


def get_dec(dectype, **kwargs):
    if dectype == "slab2d":
        return SlabContentDecoder(**kwargs, use3dconv=False)
    elif dectype == "slab2d3d":
        return SlabContentDecoder(**kwargs, use3dconv=True)
    elif dectype == "slab2d3dv2":
        return SlabContentDecoder(**kwargs, use3dconv=True, reduced3dch=True)
    else:
        raise ValueError("unknown decoder type: {}".format(dectype))


# motion model for the delta from mesh-based position/orientation
class DeconvMotionModel(nn.Module):
    def __init__(self, nprims, inch, outch, chstart=1024, norm=None, mod=False, elr=True):
        super(DeconvMotionModel, self).__init__()

        self.nprims = nprims

        self.nprimy = int(math.sqrt(nprims))
        self.nprimx = nprims // int(math.sqrt(nprims))
        assert nprims == self.nprimx * self.nprimy

        nlayers = int(math.log2(min(self.nprimx, self.nprimy)))

        ch0, ch1 = chstart, chstart // 2

        layers = []
        layers.append(LinearELR(inch, ch0, norm=norm, act=nn.LeakyReLU(0.2)))
        layers.append(Reshape(-1, ch0, 1, self.nprimx // self.nprimy))

        dims = (1, 1, self.nprimx // self.nprimy)

        for i in range(nlayers):
            layers.append(ConvTranspose2dELR(
                ch0,
                (outch if i == nlayers - 1 else ch1),
                4, 2, 1,
                norm=None if i == nlayers - 1 else norm,
                act=None if i == nlayers - 1 else nn.LeakyReLU(0.2)
                ))

            if ch0 == ch1:
                ch1 = ch0 // 2
            else:
                ch0 = ch1

        self.mod = nn.Sequential(*layers)

    def forward(self, encoding):
        out = self.mod(encoding)
        out = out.view(encoding.size(0), 9, -1).permute(0, 2, 1).contiguous()

        primposdelta = out[:, :, 0:3]
        primrvecdelta = out[:, :, 3:6]
        primscaledelta = out[:, :, 6:9]
        return primposdelta, primrvecdelta, primscaledelta


def get_motion(motiontype, **kwargs):
    if motiontype == "deconv":
        return DeconvMotionModel(**kwargs)
    else:
        raise ValueError("unknown motion model type: {}".format(motiontype))
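

# Shape sketch for the motion branch (illustrative, not in the original
# source): with nprims=256 the deconv stack upsamples a [B, chstart, 1, 1]
# feature map to a 16x16 grid with 9 channels, which forward() splits into
# per-primitive deltas:
#
#   motiondec = get_motion("deconv", nprims=256, inch=256, outch=9)
#   posdelta, rvecdelta, scaledelta = motiondec(torch.randn(2, 256))
#   # each of the three outputs has shape [2, 256, 3]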


class Decoder(nn.Module):
    def __init__(self, vt, vertmean, vertstd, idxim, tidxim, barim, volradius,
            dectype="slab2d",
            nprims=512, primsize=(32, 32, 32),
            chstart=256, penultch=None,
            condsize=0, motiontype="deconv",
            warptype=None, warpprimsize=None,
            sharedrgba=False,
            norm=None, mod=False, elr=True, scalemult=2.,
            nogeo=False, notplateact=False,
            postrainstart=-1, alphatrainstart=-1,
            renderoptions={}, **kwargs):
        """
        Parameters
        ----------
        vt : numpy.array [V, 2]
            mesh vertex texture coordinates
        vertmean : numpy.array [V, 3]
            mesh vertex position average (average over time)
        vertstd : float
            mesh vertex position standard deviation (over time)
        idxim : torch.Tensor
            texture map of triangle indices
        tidxim : torch.Tensor
            texture map of texture triangle indices
        barim : torch.Tensor
            texture map of barycentric coordinates
        volradius : float
            radius of bounding volume of scene
        dectype : string
            type of content decoder, options are "slab2d", "slab2d3d",
            "slab2d3dv2"
        nprims : int
            number of primitives
        primsize : Tuple[int, int, int]
            size of primitive dimensions
        postrainstart : int
            training iterations to start learning position, rotation, and
            scaling (i.e., primitives stay frozen until this iteration number)
        condsize : int
            unused
        motiontype : string
            motion model; only "deconv" is implemented here
        warptype : string
            warp model, options are "same" to use same architecture as content
            or None
        sharedrgba : bool
            True to use 1 branch to output rgba, False to use 1 branch for rgb
            and 1 branch for alpha
        """
        super(Decoder, self).__init__()

        self.volradius = volradius
        self.postrainstart = postrainstart
        self.alphatrainstart = alphatrainstart

        self.nprims = nprims
        self.primsize = primsize

        self.motiontype = motiontype
        self.nogeo = nogeo
        self.notplateact = notplateact
        self.scalemult = scalemult

        self.enc = LinearELR(256 + condsize, 256)

        # vertex output
        if not self.nogeo:
            self.geobranch = LinearELR(256, vertmean.numel(), norm=None)

        # primitive motion delta decoder
        self.motiondec = get_motion(motiontype, nprims=nprims, inch=256, outch=9,
                norm=norm, mod=mod, elr=elr, **kwargs)

        # slab decoder (RGBA)
        if sharedrgba:
            self.rgbadec = get_dec(dectype, nprims=nprims, primsize=primsize,
                    inch=256+3, outch=4, norm=norm, mod=mod, elr=elr,
                    penultch=penultch, **kwargs)

            if renderoptions.get("half", False):
                self.rgbadec = self.rgbadec.half()
            if renderoptions.get("chlastconv", False):
                self.rgbadec = self.rgbadec.to(memory_format=torch.channels_last)
        else:
            self.rgbdec = get_dec(dectype, nprims=nprims, primsize=primsize,
                    inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            self.alphadec = get_dec(dectype, nprims=nprims, primsize=primsize,
                    inch=256, outch=1, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            self.rgbadec = None

            if renderoptions.get("half", False):
                self.rgbdec = self.rgbdec.half()
                self.alphadec = self.alphadec.half()
            if renderoptions.get("chlastconv", False):
                self.rgbdec = self.rgbdec.to(memory_format=torch.channels_last)
                self.alphadec = self.alphadec.to(memory_format=torch.channels_last)

        # warp field decoder
        if warptype is not None:
            self.warpdec = get_dec(warptype, nprims=nprims, primsize=warpprimsize,
                    inch=256, outch=3, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, **kwargs)
        else:
            self.warpdec = None

        # vertex/triangle/mesh topology data
        if vt is not None:
            vt = torch.tensor(vt) if not isinstance(vt, torch.Tensor) else vt
            self.register_buffer("vt", vt, persistent=False)
        if vertmean is not None:
            self.register_buffer("vertmean", vertmean, persistent=False)
            self.vertstd = vertstd
        idxim = torch.tensor(idxim) if not isinstance(idxim, torch.Tensor) else idxim
        tidxim = torch.tensor(tidxim) if not isinstance(tidxim, torch.Tensor) else tidxim
        barim = torch.tensor(barim) if not isinstance(barim, torch.Tensor) else barim
        self.register_buffer("idxim", idxim.long(), persistent=False)
        self.register_buffer("tidxim", tidxim.long(), persistent=False)
        self.register_buffer("barim", barim, persistent=False)
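
    # Decoding overview (summary comment, not in the original source): forward()
    # first decodes mesh vertices and anchors each primitive to the mesh using
    # the barycentric/TBN texture buffers registered above, then adds the
    # motion-branch position/rotation/scale deltas, and finally decodes RGB(A)
    # slab content (and, optionally, a warp field) for every primitive.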

    def forward(self, encoding, viewpos,
            condinput : Optional[torch.Tensor]=None,
            renderoptions : Optional[Dict[str, str]]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            losslist : Optional[List[str]]=None,
            modelmatrix : Optional[torch.Tensor]=None):
        """
        Parameters
        ----------
        encoding : torch.Tensor [B, 256]
            Encoding of current frame
        viewpos : torch.Tensor [B, 3]
            Viewing position of target camera view
        condinput : torch.Tensor [B, ?]
            Additional conditioning input (e.g., headpose)
        renderoptions : dict
            Options for rendering (e.g., rendering debug images)
        trainiter : int
            Current training iteration
        evaliter : torch.Tensor
            Current evaluation iteration (used by the "viewslab" visualization)
        losslist : list
            List of losses to compute and return
        modelmatrix : torch.Tensor
            Unused

        Returns
        -------
        result : dict
            Contains predicted vertex positions and primitive contents,
            positions, orientations, and scales.
        losses : dict
            Losses requested in losslist.
        """
        assert renderoptions is not None
        assert losslist is not None

        if condinput is not None:
            encoding = torch.cat([encoding, condinput], dim=1)

        encoding = self.enc(encoding)

        viewdirs = F.normalize(viewpos, dim=1)

        if int(math.sqrt(self.nprims)) ** 2 == self.nprims:
            nprimsy = int(math.sqrt(self.nprims))
        else:
            nprimsy = int(math.sqrt(self.nprims // 2))
        nprimsx = self.nprims // nprimsy
        assert nprimsx * nprimsy == self.nprims

        if not self.nogeo:
            # decode mesh vertices
            # geo [6, 7306, 3]
            geo = self.geobranch(encoding)
            geo = geo.view(encoding.size(0), -1, 3)
            geo = geo * self.vertstd + self.vertmean

            # placement of primitives on mesh
            uvheight, uvwidth = self.barim.size(0), self.barim.size(1)
            stridey = uvheight // nprimsy
            stridex = uvwidth // nprimsx

            # get subset of vertices and texture map coordinates to compute TBN matrix
            v0 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 0], :]
            v1 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 1], :]
            v2 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 2], :]
            vt0 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 0], :]
            vt1 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 1], :]
            vt2 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 2], :]

            # [6, 256, 3]
            primposmesh = (
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 0, None] * v0 +
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 1, None] * v1 +
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 2, None] * v2
                ).view(v0.size(0), self.nprims, 3) / self.volradius

            # compute TBN matrix
            # primrotmesh [6, 16, 16, 3, 3]
            primrotmesh = compute_tbn(v0, v1, v2, vt0, vt1, vt2)

            # decode motion deltas [6, 256, 3]
            primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
            if trainiter <= self.postrainstart:
                primposdelta = primposdelta * 0.
                primrvecdelta = primrvecdelta * 0.
                primscaledelta = primscaledelta * 0.

            # compose mesh transform with deltas
            primpos = primposmesh + primposdelta * 0.01
            primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.01)
            primrot = torch.bmm(
                primrotmesh.view(-1, 3, 3),
                primrotdelta.view(-1, 3, 3)).view(encoding.size(0), self.nprims, 3, 3)
            primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * torch.exp(primscaledelta * 0.01)

            primtransf = None
        else:
            geo = None

            # decode motion deltas
            primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
            if trainiter <= self.postrainstart:
                primposdelta = primposdelta * 0.
                primrvecdelta = primrvecdelta * 0.
                primscaledelta = primscaledelta * 0. + 1.

            primpos = primposdelta * 0.3
            primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.3)
            primrot = torch.exp(primrotdelta * 0.01)
            primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * primscaledelta

            primtransf = None
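
        # Note on scale (illustrative comment, not in the original source): the
        # base factor self.scalemult * int(self.nprims ** (1. / 3)) keeps
        # primitives small relative to the volume at initialization, e.g.
        # 2. * 6 == 12 for scalemult=2. and nprims=256. primscale acts as an
        # inverse extent: the "primvolsum" prior below measures primitive
        # volume as prod(1. / primscale).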

        # options
        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")
        half = renderoptions.get("half")

        if self.rgbadec is not None:
            # shared rgb and alpha branch
            scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
            bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
            if chlast is not None and bool(chlast):
                scale = scale[None, None, None, None, None, :]
                bias = bias[None, None, None, None, None, :]
            else:
                scale = scale[None, None, :, None, None, None]
                bias = bias[None, None, :, None, None, None]

            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            template = bias + scale * template
            if not self.notplateact:
                template = F.relu(template)
            if half is not None and bool(half):
                template = template.float()
        else:
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            # primrgb [6, 256, 32, 32, 32, 3] -> [B, 256, primsize, 3]
            primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            primrgb = primrgb * 25. + 100.
            if not self.notplateact:
                primrgb = F.relu(primrgb)

            templatein = encoding
            if half is not None and bool(half):
                templatein = templatein.half()
            primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            if not self.notplateact:
                primalpha = F.relu(primalpha)

            if trainiter <= self.alphatrainstart:
                primalpha = primalpha * 0. + 1.

            if algo is not None and int(algo) == 4:
                template = torch.cat([primrgb, primalpha], dim=-1)
            elif chlast is not None and bool(chlast):
                template = torch.cat([primrgb, primalpha], dim=-1)
            else:
                template = torch.cat([primrgb, primalpha], dim=2)
            if half is not None and bool(half):
                template = template.float()

        if self.warpdec is not None:
            warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
            warp = warp + torch.stack(torch.meshgrid(
                torch.linspace(-1., 1., self.primsize[2], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[1], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[0], device=encoding.device))[::-1],
                dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :]
        else:
            warp = None

        # debugging / visualization
        viewaxes = renderoptions.get("viewaxes")
        colorprims = renderoptions.get("colorprims")
        viewslab = renderoptions.get("viewslab")

        # add axes to primitives
        if viewaxes is not None and bool(viewaxes):
            template[:, :, 3, template.size(3)//2:template.size(3)//2+1,
                    template.size(4)//2:template.size(4)//2+1, :] = 2550.
            template[:, :, 0, template.size(3)//2:template.size(3)//2+1,
                    template.size(4)//2:template.size(4)//2+1, :] = 2550.
            template[:, :, 3, template.size(3)//2:template.size(3)//2+1, :,
                    template.size(5)//2:template.size(5)//2+1] = 2550.
            template[:, :, 1, template.size(3)//2:template.size(3)//2+1, :,
                    template.size(5)//2:template.size(5)//2+1] = 2550.
            template[:, :, 3, :, template.size(4)//2:template.size(4)//2+1,
                    template.size(5)//2:template.size(5)//2+1] = 2550.
            template[:, :, 2, :, template.size(4)//2:template.size(4)//2+1,
                    template.size(5)//2:template.size(5)//2+1] = 2550.

        # give each primitive a unique color
        if colorprims is not None and bool(colorprims):
            lightdir = -torch.tensor([1., 1., 1.], device=template.device)
            lightdir = lightdir / torch.sqrt(torch.sum(lightdir ** 2))

            zz, yy, xx = torch.meshgrid(
                torch.linspace(-1., 1., self.primsize[2], device=template.device),
                torch.linspace(-1., 1., self.primsize[1], device=template.device),
                torch.linspace(-1., 1., self.primsize[0], device=template.device))
            primnormalx = torch.where(
                (torch.abs(xx) >= torch.abs(yy)) & (torch.abs(xx) >= torch.abs(zz)),
                torch.sign(xx) * torch.ones_like(xx),
                torch.zeros_like(xx))
            primnormaly = torch.where(
                (torch.abs(yy) >= torch.abs(xx)) & (torch.abs(yy) >= torch.abs(zz)),
                torch.sign(yy) * torch.ones_like(xx),
                torch.zeros_like(xx))
            primnormalz = torch.where(
                (torch.abs(zz) >= torch.abs(xx)) & (torch.abs(zz) >= torch.abs(yy)),
                torch.sign(zz) * torch.ones_like(xx),
                torch.zeros_like(xx))
            primnormal = torch.stack([primnormalx, primnormaly, primnormalz], dim=-1)
            primnormal = F.normalize(primnormal, dim=-1)

            torch.manual_seed(123456)

            gridz, gridy, gridx = torch.meshgrid(
                torch.linspace(-1., 1., self.primsize[2], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[1], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[0], device=encoding.device))
            grid = torch.stack([gridx, gridy, gridz], dim=-1)

            if chlast is not None and chlast:
                template[:] = torch.rand(1, template.size(1), 1, 1, 1,
                        template.size(-1), device=template.device) * 255.
                template[:, :, :, :, :, 3] = 1000.
            else:
                template[:] = torch.rand(1, template.size(1), template.size(2),
                        1, 1, 1, device=template.device) * 255.
                template[:, :, 3, :, :, :] = 1000.

            if chlast is not None and chlast:
                lightdir0 = torch.sum(primrot[:, :, :, :] * lightdir[None, None, :, None], dim=-2)
                template[:, :, :, :, :, :3] *= 1.2 * torch.sum(
                    lightdir0[:, :, None, None, None, :] * primnormal,
                    dim=-1)[:, :, :, :, :, None].clamp(min=0.05)
            else:
                lightdir0 = torch.sum(primrot[:, :, :, :] * lightdir[None, None, :, None], dim=-2)
                template[:, :, :3, :, :, :] *= 1.2 * torch.sum(
                    lightdir0[:, :, None, None, None, :] * primnormal,
                    dim=-1)[:, :, None, :, :, :].clamp(min=0.05)

        # view slab as a 2d grid
        if viewslab is not None and bool(viewslab):
            assert evaliter is not None

            yy, xx = torch.meshgrid(
                torch.linspace(0., 1., int(math.sqrt(self.nprims)), device=template.device),
                torch.linspace(0., 1., int(math.sqrt(self.nprims)), device=template.device))
            primpos0 = torch.stack([xx*1.5, 0.75-yy*1.5, xx*0.+0.5], dim=-1)[None, :, :, :].repeat(
                template.size(0), 1, 1, 1).view(-1, self.nprims, 3)
            primrot0 = torch.eye(3, device=template.device)[None, None, :, :].repeat(
                template.size(0), self.nprims, 1, 1)
            primrot0.data[:, :, 1, 1] *= -1.
            primscale0 = torch.ones((template.size(0), self.nprims, 3),
                    device=template.device) * math.sqrt(self.nprims) * 1.25 #* 0.5

            blend = ((evaliter - 256.) / 64.).clamp(min=0., max=1.)[:, None, None]
            blend = 3 * blend ** 2 - 2 * blend ** 3

            primpos = (1. - blend) * primpos0 + blend * primpos
            primrot = models.utils.rotation_interp(primrot0, primrot, blend)
            primscale = torch.exp((1. - blend) * torch.log(primscale0) + blend * torch.log(primscale))

        losses = {}

        # prior on primitive volume
        if "primvolsum" in losslist:
            losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1)

        if "logprimscalevar" in losslist:
            logprimscale = torch.log(primscale)
            logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True)
            losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2)

        result = {
            "template": template,
            "primpos": primpos,
            "primrot": primrot,
            "primscale": primscale}
        if primtransf is not None:
            result["primtransf"] = primtransf
        if warp is not None:
            result["warp"] = warp
        if geo is not None:
            result["verts"] = geo
        return result, losses
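

# Usage sketch (illustrative, not part of the original source). With
# encoding [B, 256], viewpos [B, 3], nprims=256, primsize=(32, 32, 32), and
# mesh data (vt, vertmean, idxim, tidxim, barim) prepared by the data loader:
#
#   dec = Decoder(vt, vertmean, vertstd, idxim, tidxim, barim, volradius,
#                 dectype="slab2d3d", nprims=256, primsize=(32, 32, 32))
#   result, losses = dec(encoding, viewpos, renderoptions={}, losslist=[])
#
#   result["template"]  : [B, 256, 4, 32, 32, 32]  RGBA content per primitive
#                         ([B, 256, 32, 32, 32, 4] when renderoptions["chlast"] is set)
#   result["primpos"]   : [B, 256, 3]     primitive centers (normalized volume coords)
#   result["primrot"]   : [B, 256, 3, 3]  primitive orientations
#   result["primscale"] : [B, 256, 3]     per-axis inverse primitive extents
#   result["verts"]     : [B, V, 3]       decoded mesh vertices (when nogeo=False)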