# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Neural Volumes decoder."""
import math
from typing import Optional, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.utils import LinearELR, ConvTranspose2dELR, ConvTranspose3dELR
class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)
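
# Usage sketch (illustrative values): Reshape(-1, 4, 2, 2) maps a [B, 16]
# tensor to [B, 4, 2, 2]; the leading -1 lets view() infer the batch size.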

class ContentDecoder(nn.Module):
    def __init__(self, primsize, inch, outch, chstart=256, hstart=4,
            texwarp=False, elr=True, norm=None, mod=False, ub=True, upconv=None,
            penultch=None):
        super(ContentDecoder, self).__init__()

        assert not texwarp
        assert upconv is None

        self.primsize = primsize

        nlayers = int(math.log2(self.primsize / hstart))

        lastch = chstart
        dims = (hstart, hstart, hstart)

        layers = []
        layers.append(LinearELR(inch, chstart*dims[0]*dims[1]*dims[2], act=nn.LeakyReLU(0.2)))
        layers.append(Reshape(-1, chstart, dims[0], dims[1], dims[2]))
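        # Each transposed conv doubles the spatial resolution, and the channel
        # count is halved on every other layer; e.g. primsize=128 with hstart=4
        # gives 5 upsampling layers: 4^3 -> 8^3 -> 16^3 -> 32^3 -> 64^3 -> 128^3.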
        for i in range(nlayers):
            nextch = lastch if i % 2 == 0 else lastch // 2

            if i == nlayers - 2 and penultch is not None:
                nextch = penultch

            layers.append(ConvTranspose3dELR(
                lastch,
                (outch if i == nlayers - 1 else nextch),
                4, 2, 1,
                ub=(dims[0]*2, dims[1]*2, dims[2]*2) if ub else None,
                norm=None if i == nlayers - 1 else norm,
                act=None if i == nlayers - 1 else nn.LeakyReLU(0.2)))

            lastch = nextch
            dims = (dims[0] * 2, dims[1] * 2, dims[2] * 2)

        self.mod = nn.Sequential(*layers)
    def forward(self, enc, renderoptions : Dict[str, str], trainiter : Optional[int]=None):
        x = self.mod(enc)

        chlast = renderoptions.get("chlast")
        if chlast is not None and bool(chlast):
            # reorder channels last and add a singleton primitive dimension
            x = x.permute(0, 2, 3, 4, 1)[:, None, :, :, :, :].contiguous()
        else:
            # add a singleton primitive dimension
            x = x[:, None, :, :, :, :].contiguous()
        return x
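
# Usage sketch (illustrative values, not canonical settings): decode a 259-dim
# input into a 32^3 RGB volume; output is [B, 1, outch, D, H, W], or
# [B, 1, D, H, W, outch] when renderoptions["chlast"] is set.
#   dec = ContentDecoder(primsize=32, inch=259, outch=3, chstart=64)
#   vol = dec(torch.randn(2, 259), renderoptions={})
#   assert vol.shape == (2, 1, 3, 32, 32, 32)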

def get_dec(dectype, **kwargs):
    if dectype == "conv":
        return ContentDecoder(**kwargs)
    else:
        raise ValueError("unknown decoder type: {}".format(dectype))

class Decoder(nn.Module):
    def __init__(self,
            volradius,
            dectype="conv",
            primsize=128,
            chstart=256,
            penultch=None,
            condsize=0,
            warptype="conv",
            warpprimsize=32,
            sharedrgba=False,
            norm=None,
            mod=False,
            elr=True,
            notplateact=False,
            postrainstart=-1,
            alphatrainstart=-1,
            renderoptions={},
            **kwargs):
        """
        Parameters
        ----------
        volradius : float
            radius of bounding volume of scene
        dectype : string
            type of content decoder; only "conv" is supported
        primsize : int
            spatial resolution of the (cubic) content volume
        condsize : int
            size of additional conditioning input appended to the encoding
        warptype : string
            warp field decoder type ("conv"), or None to disable warping
        warpprimsize : int
            spatial resolution of the (cubic) warp field volume
        sharedrgba : bool
            True to use 1 branch to output rgba, False to use 1 branch for rgb
            and 1 branch for alpha
        postrainstart : int
            training iteration at which position, rotation, and scaling start
            being learned (i.e., primitives stay frozen until this iteration)
        alphatrainstart : int
            training iteration at which alpha starts being learned (alpha is
            clamped to 1 before this iteration)
        """
        super(Decoder, self).__init__()

        self.volradius = volradius
        self.postrainstart = postrainstart
        self.alphatrainstart = alphatrainstart

        self.primsize = primsize
        self.warpprimsize = warpprimsize

        self.notplateact = notplateact
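        # fuse the per-frame encoding with any extra conditioning input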
        self.enc = LinearELR(256 + condsize, 256)

        # slab decoder (RGBA)
        if sharedrgba:
            self.rgbadec = get_dec(dectype, primsize=primsize,
                    inch=256+3, outch=4, norm=norm, mod=mod, elr=elr,
                    penultch=penultch, **kwargs)

            if renderoptions.get("half", False):
                self.rgbadec = self.rgbadec.half()

            if renderoptions.get("chlastconv", False):
                self.rgbadec = self.rgbadec.to(memory_format=torch.channels_last)
        else:
            self.rgbdec = get_dec(dectype, primsize=primsize,
                    inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            self.alphadec = get_dec(dectype, primsize=primsize,
                    inch=256, outch=1, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            self.rgbadec = None

            if renderoptions.get("half", False):
                self.rgbdec = self.rgbdec.half()
                self.alphadec = self.alphadec.half()

            if renderoptions.get("chlastconv", False):
                self.rgbdec = self.rgbdec.to(memory_format=torch.channels_last)
                self.alphadec = self.alphadec.to(memory_format=torch.channels_last)

        # warp field decoder
        if warptype is not None:
            self.warpdec = get_dec(warptype, primsize=warpprimsize,
                    inch=256, outch=3, chstart=chstart, norm=norm, mod=mod, elr=elr, **kwargs)
        else:
            self.warpdec = None
    def forward(self,
            encoding,
            viewpos,
            condinput : Optional[torch.Tensor]=None,
            renderoptions : Optional[Dict[str, str]]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            losslist : Optional[List[str]]=None,
            modelmatrix : Optional[torch.Tensor]=None):
        """
        Parameters
        ----------
        encoding : torch.Tensor [B, 256]
            Encoding of current frame
        viewpos : torch.Tensor [B, 3]
            Viewing position of target camera view
        condinput : torch.Tensor [B, ?]
            Additional conditioning input (e.g., headpose)
        renderoptions : dict
            Options for rendering (e.g., rendering debug images)
        trainiter : int
            Current training iteration
        losslist : list
            List of losses to compute and return

        Returns
        -------
        result : dict
            Contains the decoded primitive contents ("template"), positions
            ("primpos"), orientations ("primrot"), scales ("primscale"), and,
            if a warp decoder is present, the warp field ("warp")
        losses : dict
            Requested losses from losslist
        """
        assert renderoptions is not None
        assert losslist is not None

        if condinput is not None:
            encoding = torch.cat([encoding, condinput], dim=1)

        encoding = self.enc(encoding)

        viewdirs = F.normalize(viewpos, dim=1)
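        # Neural Volumes uses a single primitive: one slot at the origin with
        # identity rotation and unit scale; only the volume contents and the
        # warp field are decoded from the encoding.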
        primpos = torch.zeros(encoding.size(0), 1, 3, device=encoding.device)
        primrot = torch.eye(3, device=encoding.device)[None, None, :, :].repeat(encoding.size(0), 1, 1, 1)
        primscale = torch.ones(encoding.size(0), 1, 3, device=encoding.device)

        # options
        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")
        half = renderoptions.get("half")

        if self.rgbadec is not None:
            # shared rgb and alpha branch
            scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
            bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
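            # map raw decoder output to roughly [0, 255] RGB (25x + 100) while
            # leaving the alpha channel unscaled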
            if chlast is not None and bool(chlast):
                scale = scale[None, None, None, None, None, :]
                bias = bias[None, None, None, None, None, :]
            else:
                scale = scale[None, None, :, None, None, None]
                bias = bias[None, None, :, None, None, None]

            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            template = bias + scale * template
            if not self.notplateact:
                template = F.relu(template)
            if half is not None and bool(half):
                template = template.float()
        else:
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            primrgb = primrgb * 25. + 100.
            if not self.notplateact:
                primrgb = F.relu(primrgb)

            templatein = encoding
            if half is not None and bool(half):
                templatein = templatein.half()
            primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            if not self.notplateact:
                primalpha = F.relu(primalpha)
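            # until alphatrainstart, clamp alpha to a constant 1 so early
            # training fits RGB only (an arithmetic identity that preserves
            # shape, dtype, and device)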
            if trainiter <= self.alphatrainstart:
                primalpha = primalpha * 0. + 1.

            if algo is not None and int(algo) == 4:
                template = torch.cat([primrgb, primalpha], dim=-1)
            elif chlast is not None and bool(chlast):
                template = torch.cat([primrgb, primalpha], dim=-1)
            else:
                template = torch.cat([primrgb, primalpha], dim=2)
            if half is not None and bool(half):
                template = template.float()

        if self.warpdec is not None:
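            # the decoded warp is a small offset (scaled by 0.01) added to an
            # identity sampling grid; [::-1] reverses meshgrid's axis order so
            # the stacked channels come out as x, y, z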
            warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
            warp = warp + torch.stack(torch.meshgrid(
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device))[::-1],
                dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :]
            warp = warp.contiguous()
        else:
            warp = None

        losses = {}

        # prior on primitive volume
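        # (primvolsum sums prod(1 / primscale) over primitives; assuming larger
        # primscale means a smaller primitive, this penalizes total volume)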
if "primvolsum" in losslist: | |
losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1) | |
if "logprimscalevar" in losslist: | |
logprimscale = torch.log(primscale) | |
logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True) | |
losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2) | |
result = { | |
"template": template, | |
"primpos": primpos, | |
"primrot": primrot, | |
"primscale": primscale} | |
if warp is not None: | |
result["warp"] = warp | |
return result, losses | |
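
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not canonical training settings): the
    # sizes below are chosen small so it runs quickly on CPU; volradius only
    # scales scene coordinates elsewhere in the pipeline, so any positive
    # value works for a shape check.
    dec = Decoder(volradius=256., primsize=32, warpprimsize=16, chstart=64)
    out, losses = dec(
        torch.randn(2, 256),            # per-frame encoding
        torch.randn(2, 3),              # target view position
        renderoptions={},
        losslist=["primvolsum"])
    print({k: tuple(v.shape) for k, v in out.items()})
    # expected: template (2, 1, 4, 32, 32, 32), warp (2, 1, 3, 16, 16, 16)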