# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""PyTorch utilities"""
from collections import OrderedDict
from itertools import islice
import math
import operator
from typing import Optional, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

def xaviermultiplier(m, gain):
    """Return the Xavier std for a supported module type, or None."""
    if isinstance(m, nn.Conv1d):
        ksize = m.kernel_size[0]
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.ConvTranspose1d):
        ksize = m.kernel_size[0] // m.stride[0]
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.Conv2d):
        ksize = m.kernel_size[0] * m.kernel_size[1]
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.ConvTranspose2d):
        ksize = m.kernel_size[0] * m.kernel_size[1] // m.stride[0] // m.stride[1]
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.Conv3d):
        ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2]
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.ConvTranspose3d):
        ksize = (m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2]
                 // m.stride[0] // m.stride[1] // m.stride[2])
        n1 = m.in_channels
        n2 = m.out_channels

        std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, nn.Linear):
        n1 = m.in_features
        n2 = m.out_features

        std = gain * math.sqrt(2.0 / (n1 + n2))
    else:
        return None

    return std

### normal initialization routines
def xavier_uniform_(m, gain):
    std = xaviermultiplier(m, gain)
    if std is None:
        # unsupported module type; leave the weights untouched
        return
    m.weight.data.uniform_(-std * math.sqrt(3.0), std * math.sqrt(3.0))

def initmod(m, gain=1.0, weightinitfunc=xavier_uniform_):
    validclasses = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]
    if any(isinstance(m, x) for x in validclasses):
        weightinitfunc(m, gain)
        if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor):
            m.bias.data.zero_()

    # blockwise initialization for transposed convs
    if isinstance(m, nn.ConvTranspose2d):
        # hardcoded for stride=2 for now
        m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]

    if isinstance(m, nn.ConvTranspose3d):
        # hardcoded for stride=2 for now
        m.weight.data[:, :, 0::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 0::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 0::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 0::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2]

    if isinstance(m, (Conv2dWNUB, Conv2dWN, ConvTranspose2dWN,
                      ConvTranspose2dWNUB, LinearWN)):
        norm = torch.sqrt(torch.sum(m.weight.data[:] ** 2))
        m.g.data[:] = norm
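
# Usage sketch (illustrative, not part of the original API): initmod computes
# a Xavier std via xaviermultiplier, fills the weight uniformly, and zeroes
# the bias. The layer sizes and gain below are arbitrary example values.
def _example_initmod():
    conv = nn.Conv2d(16, 32, kernel_size=3)
    initmod(conv, gain=nn.init.calculate_gain('relu'))
    return conv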

def initseq(s):
    """Initialize each layer of a sequential container with a gain matched
    to the activation that follows it."""
    for a, b in zip(s[:-1], s[1:]):
        if isinstance(b, nn.ReLU):
            initmod(a, nn.init.calculate_gain('relu'))
        elif isinstance(b, nn.LeakyReLU):
            initmod(a, nn.init.calculate_gain('leaky_relu', b.negative_slope))
        elif isinstance(b, nn.Sigmoid):
            initmod(a)
        elif isinstance(b, nn.Softplus):
            initmod(a)
        else:
            initmod(a)
    initmod(s[-1])

### custom modules
class LinearWN(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(LinearWN, self).__init__(in_features, out_features, bias)
        self.g = nn.Parameter(torch.ones(out_features))
        self.fused = False

    def fuse(self):
        # bake the weight-norm scale into the weights for inference
        wnorm = torch.sqrt(torch.sum(self.weight ** 2))
        self.weight.data = self.weight.data * self.g.data[:, None] / wnorm
        self.fused = True

    def forward(self, input):
        if self.fused:
            return F.linear(input, self.weight, self.bias)
        else:
            wnorm = torch.sqrt(torch.sum(self.weight ** 2))
            return F.linear(input, self.weight * self.g[:, None] / wnorm, self.bias)

class LinearELR(nn.Module):
    """Linear layer with equalized learning rate from stylegan2"""
    def __init__(self, inch, outch, lrmult=1., norm : Optional[str]=None, act=None):
        super(LinearELR, self).__init__()

        # compute gain from activation fn
        try:
            if isinstance(act, nn.LeakyReLU):
                actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope)
            elif isinstance(act, nn.ReLU):
                actgain = nn.init.calculate_gain("relu")
            else:
                actgain = nn.init.calculate_gain(act)
        except ValueError:
            # calculate_gain doesn't recognize the activation; fall back to 1
            actgain = 1.

        initgain = 1. / math.sqrt(inch)

        self.weight = nn.Parameter(torch.randn(outch, inch) / lrmult)
        self.weightgain = actgain
        if norm is None:
            self.weightgain = self.weightgain * initgain * lrmult
        self.bias = nn.Parameter(torch.full([outch], 0.))
        self.norm : Optional[str] = norm
        self.act = act
        self.fused = False

    def extra_repr(self):
        return 'inch={}, outch={}, norm={}, act={}'.format(
            self.weight.size(1), self.weight.size(0), self.norm, self.act)

    def getweight(self):
        if self.fused:
            return self.weight
        else:
            weight = self.weight
            if self.norm is not None:
                if self.norm == "demod":
                    weight = F.normalize(weight, dim=1)
            return weight

    def fuse(self):
        if not self.fused:
            with torch.no_grad():
                self.weight.data = self.getweight() * self.weightgain
            self.fused = True

    def forward(self, x):
        if self.fused:
            weight = self.getweight()
            out = torch.addmm(self.bias[None], x, weight.t())
            if self.act is not None:
                out = self.act(out)
            return out
        else:
            weight = self.getweight()
            if self.act is None:
                out = torch.addmm(self.bias[None], x, weight.t(), alpha=self.weightgain)
                return out
            else:
                out = F.linear(x, weight * self.weightgain, bias=self.bias)
                out = self.act(out)
                return out

class Downsample2d(nn.Module):
    def __init__(self, nchannels, stride=1, padding=0):
        super(Downsample2d, self).__init__()

        self.nchannels = nchannels
        self.stride = stride
        self.padding = padding

        # separable binomial blur kernel, normalized to sum to 1
        blurkernel = torch.tensor([1., 6., 15., 20., 15., 6., 1.])
        blurkernel = (blurkernel[:, None] * blurkernel[None, :])
        blurkernel = blurkernel / torch.sum(blurkernel)
        blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1)
        self.register_buffer('kernel', blurkernel)

    def forward(self, x):
        if self.padding == "reflect":
            x = F.pad(x, (3, 3, 3, 3), mode='reflect')
            return F.conv2d(x, weight=self.kernel, stride=self.stride,
                    padding=0, groups=self.nchannels)
        else:
            return F.conv2d(x, weight=self.kernel, stride=self.stride,
                    padding=self.padding, groups=self.nchannels)
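
# Usage sketch (illustrative): a LinearELR layer with demodulation applied to
# a random batch, then fused for inference. Shapes are arbitrary examples.
def _example_linearelr():
    layer = LinearELR(256, 128, norm="demod", act=nn.LeakyReLU(0.2))
    x = torch.randn(4, 256)
    y = layer(x)            # training path: gain applied on the fly
    layer.fuse()            # bake weightgain into the stored weights
    y_fused = layer(x)      # inference path, same result up to numerics
    return y, y_fused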

class Dilate2d(nn.Module):
    def __init__(self, nchannels, kernelsize, stride=1, padding=0):
        super(Dilate2d, self).__init__()

        self.nchannels = nchannels
        self.kernelsize = kernelsize
        self.stride = stride
        self.padding = padding

        # normalized box kernel; clamping the output dilates a binary mask
        blurkernel = torch.ones((self.kernelsize,))
        blurkernel = (blurkernel[:, None] * blurkernel[None, :])
        blurkernel = blurkernel / torch.sum(blurkernel)
        blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1)
        self.register_buffer('kernel', blurkernel)

    def forward(self, x):
        return F.conv2d(x, weight=self.kernel, stride=self.stride,
                padding=self.padding, groups=self.nchannels).clamp(max=1.)

class Conv2dWN(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dWN, self).__init__(in_channels, out_channels, kernel_size,
                stride, padding, dilation, groups, True)
        self.g = nn.Parameter(torch.ones(out_channels))

    def forward(self, x):
        wnorm = torch.sqrt(torch.sum(self.weight ** 2))
        return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm,
                bias=self.bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups)

class Conv2dUB(nn.Conv2d):
    def __init__(self, in_channels, out_channels, height, width, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2dUB, self).__init__(in_channels, out_channels, kernel_size,
                stride, padding, dilation, groups, False)
        # untied bias: one bias value per output channel and spatial location
        self.bias = nn.Parameter(torch.zeros(out_channels, height, width))

    def forward(self, x):
        return F.conv2d(x, self.weight, bias=None, stride=self.stride,
                padding=self.padding, dilation=self.dilation,
                groups=self.groups) + self.bias[None, ...]

class Conv2dWNUB(nn.Conv2d):
    def __init__(self, in_channels, out_channels, height, width, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2dWNUB, self).__init__(in_channels, out_channels, kernel_size,
                stride, padding, dilation, groups, False)
        self.g = nn.Parameter(torch.ones(out_channels))
        self.bias = nn.Parameter(torch.zeros(out_channels, height, width))

    def forward(self, x):
        wnorm = torch.sqrt(torch.sum(self.weight ** 2))
        return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm,
                bias=None, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups) + self.bias[None, ...]

def blockinit(k, stride):
    """Nearest-neighbor upsample a conv kernel by `stride` along each spatial
    dimension (each tap becomes a stride-sized block)."""
    dim = k.ndim - 2
    return k \
        .view(k.size(0), k.size(1), *(x for i in range(dim) for x in (k.size(i + 2), 1))) \
        .repeat(1, 1, *(x for i in range(dim) for x in (1, stride))) \
        .view(k.size(0), k.size(1), *(k.size(i + 2) * stride for i in range(dim)))
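
# Quick check (illustrative): blockinit duplicates every kernel tap into a
# stride x stride block, i.e. nearest-neighbor upsampling of the kernel.
def _example_blockinit():
    k = torch.tensor([[[[1., 2.],
                        [3., 4.]]]])      # shape [1, 1, 2, 2]
    k2 = blockinit(k, 2)                  # shape [1, 1, 4, 4]
    # k2[0, 0] == [[1, 1, 2, 2],
    #              [1, 1, 2, 2],
    #              [3, 3, 4, 4],
    #              [3, 3, 4, 4]]
    return k2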

class ConvTranspose1dELR(nn.Module):
    def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0,
            affinelrmult=1., norm=None, ub=None, act=None):
        super(ConvTranspose1dELR, self).__init__()

        self.inch = inch
        self.outch = outch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.wsize = wsize
        self.norm = norm
        self.ub = ub
        self.act = act

        # compute gain from activation fn
        try:
            if isinstance(act, nn.LeakyReLU):
                actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope)
            elif isinstance(act, nn.ReLU):
                actgain = nn.init.calculate_gain("relu")
            else:
                actgain = nn.init.calculate_gain(act)
        except ValueError:
            actgain = 1.

        fan_in = inch * (kernel_size / stride)
        initgain = stride ** 0.5 if norm == "demod" else 1. / math.sqrt(fan_in)

        self.weightgain = actgain * initgain

        self.weight = nn.Parameter(blockinit(
            torch.randn(inch, outch, kernel_size // self.stride), self.stride))

        if ub is not None:
            self.bias = nn.Parameter(torch.zeros(outch, ub[0]))
        else:
            self.bias = nn.Parameter(torch.zeros(outch))

        if wsize > 0:
            self.affine = LinearELR(wsize, inch, lrmult=affinelrmult)
        else:
            self.affine = None

        self.fused = False

    def extra_repr(self):
        return ('inch={}, outch={}, kernel_size={}, stride={}, padding={}, '
                'wsize={}, norm={}, ub={}, act={}'.format(
                    self.inch, self.outch, self.kernel_size, self.stride,
                    self.padding, self.wsize, self.norm, self.ub, self.act))

    def getweight(self, weight):
        if self.fused:
            return weight
        else:
            if self.norm is not None:
                if self.norm == "demod":
                    if weight.ndim == 4:
                        # modulated weights carry a leading batch dim
                        normdims = [1, 3]
                    else:
                        normdims = [0, 2]
                    if torch.jit.is_scripting():
                        # scripting doesn't support F.normalize(..., dim=list[int])
                        weight = weight / torch.linalg.vector_norm(
                                weight, dim=normdims, keepdim=True)
                    else:
                        weight = F.normalize(weight, dim=normdims)

            weight = weight * self.weightgain
            return weight

    def fuse(self):
        if self.affine is None:
            with torch.no_grad():
                self.weight.data = self.getweight(self.weight)
            self.fused = True

    def forward(self, x, w : Optional[torch.Tensor]=None):
        b = x.size(0)

        if self.affine is not None and w is not None:
            # modulate
            affine = self.affine(w)[:, :, None, None] # [B, inch, 1, 1]
            weight = self.weight * (affine * 0.1 + 1.)
        else:
            weight = self.weight

        weight = self.getweight(weight)

        if self.affine is not None and w is not None:
            # batched weights -> grouped conv, one group per batch element
            x = x.view(1, b * self.inch, x.size(2))
            weight = weight.view(b * self.inch, self.outch, self.kernel_size)
            groups = b
        else:
            groups = 1

        out = F.conv_transpose1d(x, weight, None, stride=self.stride,
                padding=self.padding, dilation=1, groups=groups)

        if self.affine is not None and w is not None:
            out = out.view(b, self.outch, out.size(2))

        if self.bias.ndim == 1:
            bias = self.bias[None, :, None]
        else:
            bias = self.bias[None, :, :]
        out = out + bias

        if self.act is not None:
            out = self.act(out)

        return out

class ConvTranspose2dELR(nn.Module):
    def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0,
            affinelrmult=1., norm=None, ub=None, act=None):
        super(ConvTranspose2dELR, self).__init__()

        self.inch = inch
        self.outch = outch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.wsize = wsize
        self.norm = norm
        self.ub = ub
        self.act = act

        # compute gain from activation fn
        try:
            if isinstance(act, nn.LeakyReLU):
                actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope)
            elif isinstance(act, nn.ReLU):
                actgain = nn.init.calculate_gain("relu")
            else:
                actgain = nn.init.calculate_gain(act)
        except ValueError:
            actgain = 1.

        fan_in = inch * (kernel_size ** 2 / (stride ** 2))
        initgain = stride if norm == "demod" else 1. / math.sqrt(fan_in)

        self.weightgain = actgain * initgain

        self.weight = nn.Parameter(blockinit(
            torch.randn(inch, outch, kernel_size // self.stride,
                kernel_size // self.stride), self.stride))

        if ub is not None:
            self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1]))
        else:
            self.bias = nn.Parameter(torch.zeros(outch))

        if wsize > 0:
            self.affine = LinearELR(wsize, inch, lrmult=affinelrmult)
        else:
            self.affine = None

        self.fused = False

    def extra_repr(self):
        return ('inch={}, outch={}, kernel_size={}, stride={}, padding={}, '
                'wsize={}, norm={}, ub={}, act={}'.format(
                    self.inch, self.outch, self.kernel_size, self.stride,
                    self.padding, self.wsize, self.norm, self.ub, self.act))

    def getweight(self, weight):
        if self.fused:
            return weight
        else:
            if self.norm is not None:
                if self.norm == "demod":
                    if weight.ndim == 5:
                        # modulated weights carry a leading batch dim
                        normdims = [1, 3, 4]
                    else:
                        normdims = [0, 2, 3]
                    if torch.jit.is_scripting():
                        # scripting doesn't support F.normalize(..., dim=list[int])
                        weight = weight / torch.linalg.vector_norm(
                                weight, dim=normdims, keepdim=True)
                    else:
                        weight = F.normalize(weight, dim=normdims)

            weight = weight * self.weightgain
            return weight

    def fuse(self):
        if self.affine is None:
            with torch.no_grad():
                self.weight.data = self.getweight(self.weight)
            self.fused = True

    def forward(self, x, w : Optional[torch.Tensor]=None):
        b = x.size(0)

        if self.affine is not None and w is not None:
            # modulate
            affine = self.affine(w)[:, :, None, None, None] # [B, inch, 1, 1, 1]
            weight = self.weight * (affine * 0.1 + 1.)
        else:
            weight = self.weight

        weight = self.getweight(weight)

        if self.affine is not None and w is not None:
            # batched weights -> grouped conv, one group per batch element
            x = x.view(1, b * self.inch, x.size(2), x.size(3))
            weight = weight.view(b * self.inch, self.outch,
                    self.kernel_size, self.kernel_size)
            groups = b
        else:
            groups = 1

        out = F.conv_transpose2d(x, weight, None, stride=self.stride,
                padding=self.padding, dilation=1, groups=groups)

        if self.affine is not None and w is not None:
            out = out.view(b, self.outch, out.size(2), out.size(3))

        if self.bias.ndim == 1:
            bias = self.bias[None, :, None, None]
        else:
            bias = self.bias[None, :, :, :]
        out = out + bias

        if self.act is not None:
            out = self.act(out)

        return out
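
# Usage sketch (illustrative): a StyleGAN2-style modulated upsampling conv.
# wsize enables a per-sample affine computed from the style vector w; all
# sizes below are arbitrary example values.
def _example_convtranspose2delr():
    layer = ConvTranspose2dELR(64, 32, kernel_size=4, stride=2, padding=1,
            wsize=16, norm="demod", act=nn.LeakyReLU(0.2))
    x = torch.randn(4, 64, 8, 8)
    w = torch.randn(4, 16)
    return layer(x, w)   # -> [4, 32, 16, 16]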

class ConvTranspose3dELR(nn.Module):
    def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0,
            affinelrmult=1., norm=None, ub=None, act=None):
        super(ConvTranspose3dELR, self).__init__()

        self.inch = inch
        self.outch = outch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.wsize = wsize
        self.norm = norm
        self.ub = ub
        self.act = act

        # compute gain from activation fn
        try:
            if isinstance(act, nn.LeakyReLU):
                actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope)
            elif isinstance(act, nn.ReLU):
                actgain = nn.init.calculate_gain("relu")
            else:
                actgain = nn.init.calculate_gain(act)
        except ValueError:
            actgain = 1.

        fan_in = inch * (kernel_size ** 3 / (stride ** 3))
        initgain = stride ** 1.5 if norm == "demod" else 1. / math.sqrt(fan_in)

        self.weightgain = actgain * initgain

        self.weight = nn.Parameter(blockinit(
            torch.randn(inch, outch, kernel_size // self.stride,
                kernel_size // self.stride, kernel_size // self.stride),
            self.stride))

        if ub is not None:
            self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1], ub[2]))
        else:
            self.bias = nn.Parameter(torch.zeros(outch))

        if wsize > 0:
            self.affine = LinearELR(wsize, inch, lrmult=affinelrmult)
        else:
            self.affine = None

        self.fused = False

    def extra_repr(self):
        return ('inch={}, outch={}, kernel_size={}, stride={}, padding={}, '
                'wsize={}, norm={}, ub={}, act={}'.format(
                    self.inch, self.outch, self.kernel_size, self.stride,
                    self.padding, self.wsize, self.norm, self.ub, self.act))

    def getweight(self, weight):
        if self.fused:
            return weight
        else:
            if self.norm is not None:
                if self.norm == "demod":
                    if weight.ndim == 6:
                        # modulated weights carry a leading batch dim
                        normdims = [1, 3, 4, 5]
                    else:
                        normdims = [0, 2, 3, 4]
                    if torch.jit.is_scripting():
                        # scripting doesn't support F.normalize(..., dim=list[int])
                        weight = weight / torch.linalg.vector_norm(
                                weight, dim=normdims, keepdim=True)
                    else:
                        weight = F.normalize(weight, dim=normdims)

            weight = weight * self.weightgain
            return weight

    def fuse(self):
        if self.affine is None:
            with torch.no_grad():
                self.weight.data = self.getweight(self.weight)
            self.fused = True

    def forward(self, x, w : Optional[torch.Tensor]=None):
        b = x.size(0)

        if self.affine is not None and w is not None:
            # modulate
            affine = self.affine(w)[:, :, None, None, None, None] # [B, inch, 1, 1, 1, 1]
            weight = self.weight * (affine * 0.1 + 1.)
        else:
            weight = self.weight

        weight = self.getweight(weight)

        if self.affine is not None and w is not None:
            # batched weights -> grouped conv, one group per batch element
            x = x.view(1, b * self.inch, x.size(2), x.size(3), x.size(4))
            weight = weight.view(b * self.inch, self.outch,
                    self.kernel_size, self.kernel_size, self.kernel_size)
            groups = b
        else:
            groups = 1

        out = F.conv_transpose3d(x, weight, None, stride=self.stride,
                padding=self.padding, dilation=1, groups=groups)

        if self.affine is not None and w is not None:
            out = out.view(b, self.outch, out.size(2), out.size(3), out.size(4))

        if self.bias.ndim == 1:
            bias = self.bias[None, :, None, None, None]
        else:
            bias = self.bias[None, :, :, :, :]
        out = out + bias

        if self.act is not None:
            out = self.act(out)

        return out

class Conv2dELR(nn.Module):
    def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0,
            affinelrmult=1., norm=None, ub=None, act=None):
        super(Conv2dELR, self).__init__()

        self.inch = inch
        self.outch = outch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.wsize = wsize
        self.norm = norm
        self.ub = ub
        self.act = act

        # compute gain from activation fn
        try:
            if isinstance(act, nn.LeakyReLU):
                actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope)
            elif isinstance(act, nn.ReLU):
                actgain = nn.init.calculate_gain("relu")
            else:
                actgain = nn.init.calculate_gain(act)
        except ValueError:
            actgain = 1.

        fan_in = inch * (kernel_size ** 2)
        initgain = 1. if norm == "demod" else 1. / math.sqrt(fan_in)

        self.weightgain = actgain * initgain

        self.weight = nn.Parameter(
            torch.randn(outch, inch, kernel_size, kernel_size))

        if ub is not None:
            self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1]))
        else:
            self.bias = nn.Parameter(torch.zeros(outch))

        if wsize > 0:
            self.affine = LinearELR(wsize, inch, lrmult=affinelrmult)
        else:
            self.affine = None

        self.fused = False

    def extra_repr(self):
        return ('inch={}, outch={}, kernel_size={}, stride={}, padding={}, '
                'wsize={}, norm={}, ub={}, act={}'.format(
                    self.inch, self.outch, self.kernel_size, self.stride,
                    self.padding, self.wsize, self.norm, self.ub, self.act))

    def getweight(self, weight):
        if self.fused:
            return weight
        else:
            if self.norm is not None:
                if self.norm == "demod":
                    if weight.ndim == 5:
                        # modulated weights carry a leading batch dim
                        normdims = [2, 3, 4]
                    else:
                        normdims = [1, 2, 3]
                    if torch.jit.is_scripting():
                        # scripting doesn't support F.normalize(..., dim=list[int])
                        weight = weight / torch.linalg.vector_norm(
                                weight, dim=normdims, keepdim=True)
                    else:
                        weight = F.normalize(weight, dim=normdims)

            weight = weight * self.weightgain
            return weight

    def fuse(self):
        if self.affine is None:
            with torch.no_grad():
                self.weight.data = self.getweight(self.weight)
            self.fused = True

    def forward(self, x, w : Optional[torch.Tensor]=None):
        b = x.size(0)

        if self.affine is not None and w is not None:
            # modulate
            affine = self.affine(w)[:, None, :, None, None] # [B, 1, inch, 1, 1]
            weight = self.weight * (affine * 0.1 + 1.)
        else:
            weight = self.weight

        weight = self.getweight(weight)

        if self.affine is not None and w is not None:
            # batched weights -> grouped conv, one group per batch element
            x = x.view(1, b * self.inch, x.size(2), x.size(3))
            weight = weight.view(b * self.outch, self.inch,
                    self.kernel_size, self.kernel_size)
            groups = b
        else:
            groups = 1

        out = F.conv2d(x, weight, None, stride=self.stride,
                padding=self.padding, dilation=1, groups=groups)

        if self.affine is not None and w is not None:
            out = out.view(b, self.outch, out.size(2), out.size(3))

        if self.bias.ndim == 1:
            bias = self.bias[None, :, None, None]
        else:
            bias = self.bias[None, :, :, :]
        out = out + bias

        if self.act is not None:
            out = self.act(out)

        return out

class ConvTranspose2dWN(nn.ConvTranspose2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, bias=True):
        super(ConvTranspose2dWN, self).__init__(in_channels, out_channels,
                kernel_size, stride, padding, dilation, groups, True)
        self.g = nn.Parameter(torch.ones(out_channels))
        self.fused = False

    def fuse(self):
        wnorm = torch.sqrt(torch.sum(self.weight ** 2))
        self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm
        self.fused = True

    def forward(self, x):
        bias = self.bias
        assert bias is not None
        if self.fused:
            return F.conv_transpose2d(x, self.weight, bias=self.bias,
                    stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups)
        else:
            wnorm = torch.sqrt(torch.sum(self.weight ** 2))
            return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm,
                    bias=self.bias, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups)
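
# Usage sketch (illustrative): untied-bias convs carry a bias per output
# location, so the output resolution must match the (height, width) given at
# construction. Sizes below are arbitrary example values.
def _example_conv2dwnub():
    layer = Conv2dWNUB(8, 16, height=32, width=32, kernel_size=3, padding=1)
    x = torch.randn(2, 8, 32, 32)
    return layer(x)      # -> [2, 16, 32, 32]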

class ConvTranspose2dUB(nn.ConvTranspose2d):
    def __init__(self, width, height, in_channels, out_channels, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(ConvTranspose2dUB, self).__init__(in_channels, out_channels,
                kernel_size, stride, padding, dilation, groups, False)
        # note the (width, height, ...) argument order, unlike Conv2dUB
        self.bias_ = nn.Parameter(torch.zeros(out_channels, height, width))

    def forward(self, x):
        return F.conv_transpose2d(x, self.weight, bias=None, stride=self.stride,
                padding=self.padding, dilation=self.dilation,
                groups=self.groups) + self.bias_[None, ...]

class ConvTranspose2dWNUB(nn.ConvTranspose2d):
    def __init__(self, in_channels, out_channels, height, width, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(ConvTranspose2dWNUB, self).__init__(in_channels, out_channels,
                kernel_size, stride, padding, dilation, groups, False)
        self.g = nn.Parameter(torch.ones(out_channels))
        self.bias = nn.Parameter(torch.zeros(out_channels, height, width))
        self.fused = False

    def fuse(self):
        wnorm = torch.sqrt(torch.sum(self.weight ** 2))
        self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm
        self.fused = True

    def forward(self, x):
        bias = self.bias
        assert bias is not None
        if self.fused:
            return F.conv_transpose2d(x, self.weight, bias=None,
                    stride=self.stride, padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups) + bias[None, ...]
        else:
            wnorm = torch.sqrt(torch.sum(self.weight ** 2))
            return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm,
                    bias=None, stride=self.stride, padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups) + bias[None, ...]

class Conv3dUB(nn.Conv3d):
    def __init__(self, width, height, depth, in_channels, out_channels,
            kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv3dUB, self).__init__(in_channels, out_channels, kernel_size,
                stride, padding, dilation, groups, False)
        # untied bias: one value per output channel and voxel
        self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width))

    def forward(self, x):
        return F.conv3d(x, self.weight, bias=None, stride=self.stride,
                padding=self.padding, dilation=self.dilation,
                groups=self.groups) + self.bias[None, ...]

class ConvTranspose3dUB(nn.ConvTranspose3d):
    def __init__(self, width, height, depth, in_channels, out_channels,
            kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(ConvTranspose3dUB, self).__init__(in_channels, out_channels,
                kernel_size, stride, padding, dilation, groups, False)
        self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width))

    def forward(self, x):
        return F.conv_transpose3d(x, self.weight, bias=None, stride=self.stride,
                padding=self.padding, dilation=self.dilation,
                groups=self.groups) + self.bias[None, ...]

class Rodrigues(nn.Module):
    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        # epsilon keeps the norm (and its gradient) finite at rvec == 0
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
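
# Quick check (illustrative): Rodrigues maps axis-angle vectors to rotation
# matrices, so R @ R^T should be close to the identity.
def _example_rodrigues():
    rod = Rodrigues()
    rvec = torch.randn(8, 3)
    R = rod(rvec)                                   # [8, 3, 3]
    eye = torch.bmm(R, R.permute(0, 2, 1))
    return torch.allclose(eye, torch.eye(3).expand(8, 3, 3), atol=1e-4)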

class Quaternion(nn.Module):
    def __init__(self):
        super(Quaternion, self).__init__()

    def forward(self, rvec):
        # rvec is an unnormalized quaternion (x, y, z, w); the epsilon keeps
        # the normalization finite at zero
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        return torch.stack((
            1. - 2. * rvec[:, 1] ** 2 - 2. * rvec[:, 2] ** 2,
            2. * (rvec[:, 0] * rvec[:, 1] - rvec[:, 2] * rvec[:, 3]),
            2. * (rvec[:, 0] * rvec[:, 2] + rvec[:, 1] * rvec[:, 3]),

            2. * (rvec[:, 0] * rvec[:, 1] + rvec[:, 2] * rvec[:, 3]),
            1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 2] ** 2,
            2. * (rvec[:, 1] * rvec[:, 2] - rvec[:, 0] * rvec[:, 3]),

            2. * (rvec[:, 0] * rvec[:, 2] - rvec[:, 1] * rvec[:, 3]),
            2. * (rvec[:, 0] * rvec[:, 3] + rvec[:, 1] * rvec[:, 2]),
            1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 1] ** 2
            ), dim=1).view(-1, 3, 3)

class BufferDict(nn.Module):
    def __init__(self, d, persistent=False):
        super(BufferDict, self).__init__()

        # propagate the persistent flag to all registered buffers
        self.persistent = persistent
        for k in d:
            self.register_buffer(k, d[k], persistent=persistent)

    def __getitem__(self, key):
        return self._buffers[key]

    def __setitem__(self, key, parameter):
        self.register_buffer(key, parameter, persistent=self.persistent)

def matrix_to_axisangle(r):
    th = torch.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.))[..., None]
    vec = 0.5 * torch.stack([
        r[..., 2, 1] - r[..., 1, 2],
        r[..., 0, 2] - r[..., 2, 0],
        r[..., 1, 0] - r[..., 0, 1]], dim=-1) / torch.sin(th)
    return th, vec

@torch.jit.script
def axisangle_to_matrix(rvec : torch.Tensor):
    theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=-1))
    rvec = rvec / theta[..., None]
    costh = torch.cos(theta)
    sinth = torch.sin(theta)
    return torch.stack((
        torch.stack((
            rvec[..., 0] ** 2 + (1. - rvec[..., 0] ** 2) * costh,
            rvec[..., 0] * rvec[..., 1] * (1. - costh) - rvec[..., 2] * sinth,
            rvec[..., 0] * rvec[..., 2] * (1. - costh) + rvec[..., 1] * sinth), dim=-1),
        torch.stack((
            rvec[..., 0] * rvec[..., 1] * (1. - costh) + rvec[..., 2] * sinth,
            rvec[..., 1] ** 2 + (1. - rvec[..., 1] ** 2) * costh,
            rvec[..., 1] * rvec[..., 2] * (1. - costh) - rvec[..., 0] * sinth), dim=-1),
        torch.stack((
            rvec[..., 0] * rvec[..., 2] * (1. - costh) - rvec[..., 1] * sinth,
            rvec[..., 1] * rvec[..., 2] * (1. - costh) + rvec[..., 0] * sinth,
            rvec[..., 2] ** 2 + (1. - rvec[..., 2] ** 2) * costh), dim=-1)), dim=-2)

def rotation_interp(r0, r1, alpha):
    r0a = r0.view(-1, 3, 3)
    r1a = r1.view(-1, 3, 3)
    # relative rotation r0^T r1, scaled in axis-angle space, then reapplied
    r = torch.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)

    th, rvec = matrix_to_axisangle(r)
    rvec = rvec * (alpha * th)

    r = axisangle_to_matrix(rvec)
    return torch.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)

def fuse(trainiter=None, renderoptions=None):
    """Build a callable for nn.Module.apply that fuses supported modules."""
    if renderoptions is None:
        renderoptions = {}
    def _fuse(m):
        if hasattr(m, "fuse") and isinstance(m, torch.nn.Module):
            # co_argcount > 1 means fuse() takes arguments beyond self
            if m.fuse.__code__.co_argcount > 1:
                m.fuse(trainiter, renderoptions)
            else:
                m.fuse()
    return _fuse

def no_grad(m):
    for p in m.parameters():
        p.requires_grad = False
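
# Usage sketch (illustrative): fuse() returns a hook for nn.Module.apply, so
# weight-normalized / ELR layers can bake their gains in before inference.
def _example_fuse():
    net = nn.Sequential(LinearWN(32, 64), nn.ReLU(), LinearWN(64, 3))
    initseq(net)
    net.apply(fuse())   # calls .fuse() on every submodule that defines it
    no_grad(net)        # freeze all parameters
    return net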