# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, List

import torch
import torch.nn as nn

from models.utils import LinearELR, Conv2dELR, Downsample2d

class Encoder(torch.nn.Module):
    def __init__(self, ninputs, size, nlayers=7, conv=Conv2dELR, lin=LinearELR):
        super(Encoder, self).__init__()

        self.ninputs = ninputs
        height, width = size
        self.nlayers = nlayers

        # zero-pad the input so height and width are divisible by 2 ** nlayers
        ypad = ((height + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - height
        xpad = ((width + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - width
        self.pad = nn.ZeroPad2d((xpad // 2, xpad - xpad // 2, ypad // 2, ypad - ypad // 2))
        self.downwidth = (width + 2 ** nlayers - 1) // 2 ** nlayers
        self.downheight = (height + 2 ** nlayers - 1) // 2 ** nlayers

        # compile layers: channels grow 3 -> 64 -> ... -> 256 (capped at 256),
        # and each strided conv halves the spatial resolution
        layers = []
        inch, outch = 3, 64
        for i in range(nlayers):
            layers.append(conv(inch, outch, 4, 2, 1, norm="demod", act=nn.LeakyReLU(0.2)))
            if inch == outch:
                outch = inch * 2
            else:
                inch = outch
            if outch > 256:
                outch = 256

        # the same conv stack (shared weights) is applied to every input view,
        # followed by a joint linear bottleneck over the concatenated features
        self.down1 = nn.ModuleList([nn.Sequential(*layers) for i in range(self.ninputs)])
        self.down2 = lin(256 * self.ninputs * self.downwidth * self.downheight, 512, norm="demod", act=nn.LeakyReLU(0.2))
        self.mu = lin(512, 256)
        self.logstd = lin(512, 256)

    def forward(self, x, losslist: Optional[List[str]] = None):
        assert losslist is not None

        x = self.pad(x)
        # split the channel dimension into ninputs RGB views, encode each, and flatten
        x = [self.down1[i](x[:, i * 3:(i + 1) * 3, :, :]).view(x.size(0), 256 * self.downwidth * self.downheight)
             for i in range(self.ninputs)]
        x = torch.cat(x, dim=1)
        x = self.down2(x)

        # VAE head: sample with the reparameterization trick during training, use the mean at eval time
        mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01
        if self.training:
            z = mu + torch.exp(logstd) * torch.randn(*logstd.size(), device=logstd.device)
        else:
            z = mu

        losses = {}
        if "kldiv" in losslist:
            # per-dimension mean of the KL divergence from N(mu, exp(logstd)^2) to the standard normal prior
            losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1)

        return {"encoding": z}, losses
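
# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# The batch size, number of views, and 256x256 resolution below are
# assumptions chosen for demonstration; it only exercises the Encoder
# defined above and assumes Conv2dELR / LinearELR resolve via models.utils.
if __name__ == "__main__":
    encoder = Encoder(ninputs=3, size=(256, 256))
    # input is ninputs RGB views stacked along the channel dimension
    views = torch.randn(2, 3 * 3, 256, 256)
    preds, losses = encoder(views, losslist=["kldiv"])
    print(preds["encoding"].shape)  # torch.Size([2, 256])
    print(losses["kldiv"].shape)    # torch.Size([2])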