# Copyright (c) Facebook, Inc. and its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Mostly copy-paste from timm library. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ from copy import deepcopy import math from functools import partial from sympy import flatten import torch import torch.nn as nn from torch import Tensor, pixel_shuffle from einops import rearrange, repeat from einops.layers.torch import Rearrange from torch.nn.modules import GELU # from vit.vision_transformer import Conv3DCrossAttentionBlock from .utils import trunc_normal_ from pdb import set_trace as st # import apex try: from apex.normalization import FusedRMSNorm as RMSNorm except: # from dit.norm import RMSNorm from dit.norm import RMSNorm # from apex.normalization import FusedLayerNorm as LayerNorm try: from xformers.ops import memory_efficient_attention, unbind, fmha from xformers.ops import MemoryEfficientAttentionFlashAttentionOp # from xformers.ops import RMSNorm XFORMERS_AVAILABLE = True except ImportError: # logger.warning("xFormers not available") XFORMERS_AVAILABLE = False class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., enable_rmsnorm=False, qk_norm=False,): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) # https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78 self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity() # sd-3 self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity() # if qk_norm: # self.q_norm = LayerNorm(dim, eps=1e-5) # self.k_norm = LayerNorm(dim, eps=1e-5) self.qk_norm = qk_norm def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) # return x, attn return x class MemEffAttention(Attention): def forward(self, x: Tensor, attn_bias=None) -> Tensor: if not XFORMERS_AVAILABLE: assert attn_bias is None, "xFormers is required for nested tensors usage" return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) q, k = self.q_norm(q), self.k_norm(k) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # if not bf16, no flash-attn here. # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) # force flash attention x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x class MemEffCrossAttention(MemEffAttention): # for cross attention, where context serves as k and v def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0, proj_drop=0): super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) del self.qkv self.q = nn.Linear(dim, dim * 1, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) def forward(self, x: Tensor, context: Tensor, attn_bias=None) -> Tensor: if not XFORMERS_AVAILABLE: assert attn_bias is None, "xFormers is required for nested tensors usage" return super().forward(x) B, N, C = x.shape # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q = self.q(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) kv = self.kv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) k, v = unbind(kv, 2) # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x # https://github.com/IBM/CrossViT/blob/main/models/crossvit.py class CrossAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim**-0.5 self.wq = nn.Linear(dim, dim, bias=qkv_bias) self.wk = nn.Linear(dim, dim, bias=qkv_bias) self.wv = nn.Linear(dim, dim, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) attn = (q @ k.transpose( -2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape( B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C x = self.proj(x) x = self.proj_drop(x) return x class Conv3D_Aware_CrossAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim**-0.5 self.wq = nn.Linear(dim, dim, bias=qkv_bias) self.wk = nn.Linear(dim, dim, bias=qkv_bias) self.wv = nn.Linear(dim, dim, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, group_size, N, C = x.shape # B 3 N C p = int(N**0.5) # patch size assert p**2 == N, 'check input dim, no [cls] needed here' assert group_size == 3, 'designed for triplane here' x = x.reshape(B, group_size, p, p, C) # expand patch token dim # * init qkv # q = torch.empty(B * group_size * N, # 1, # self.num_heads, # C // self.num_heads, # device=x.device).permute(0, 2, 1, 3) # k = torch.empty(B * group_size * N, # 2 * p, # self.num_heads, # C // self.num_heads, # device=x.device).permute(0, 2, 1, 3) # v = torch.empty_like(k) q_x = torch.empty( B * group_size * N, 1, # self.num_heads, # C // self.num_heads, C, device=x.device) k_x = torch.empty( B * group_size * N, 2 * p, # self.num_heads, # C // self.num_heads, C, device=x.device) v_x = torch.empty_like(k_x) # ! refer to the following plane order # N, M, _ = coordinates.shape # xy_coords = coordinates[..., [0, 1]] # yz_coords = coordinates[..., [1, 2]] # zx_coords = coordinates[..., [2, 0]] # return torch.stack([xy_coords, yz_coords, zx_coords], # dim=1).reshape(N * 3, M, 2) index_i, index_j = torch.meshgrid(torch.arange(0, p), torch.arange(0, p), indexing='ij') # 16*16 index_mesh_grid = torch.stack([index_i, index_j], 0).to( x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(B, 2, p, p) # B 2 p p. for i in range(group_size): q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( 0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C # TODO, how to batchify gather ops? plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1] # B 1 p p C plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] assert plane_yz.shape == plane_zx.shape == ( B, 1, p, p, C), 'check sub plane dimensions' pooling_plane_yz = torch.gather( plane_yz, dim=2, index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C pooling_plane_zx = torch.gather( plane_zx, dim=3, index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C k_x[B * i * N:B * (i + 1) * N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat( [pooling_plane_yz, pooling_plane_zx], dim=2).reshape(B * N, 2 * p, C) # B 256 2 16 C => (B*256) 2*16 C # q[B * i * N: B * (i+1) * N] = self.wq(q_x).reshape(B*N, 1, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # k[B * i * N: B * (i+1) * N] = self.wk(k_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # v[B * i * N: B * (i+1) * N] = self.wv(v_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) # merge num_heads into Batch dimention k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) attn = (q @ k.transpose( -2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N, N=2p here attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape( B * 3 * N, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C x = self.proj(x) x = self.proj_drop(x) # reshape x back x = x.reshape(B, 3, N, C) return x class xformer_Conv3D_Aware_CrossAttention(nn.Module): # https://github.dev/facebookresearch/dinov2 def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() # https://pytorch.org/blog/accelerated-generative-diffusion-models/ self.num_heads = num_heads self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias) self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.index_mesh_grid = None def forward(self, x, attn_bias=None): B, group_size, N, C = x.shape # B 3 N C p = int(N**0.5) # patch size assert p**2 == N, 'check input dim, no [cls] needed here' assert group_size == 3, 'designed for triplane here' x = x.reshape(B, group_size, p, p, C) # expand patch token dim q_x = torch.empty(B * group_size * N, 1, C, device=x.device) context = torch.empty(B * group_size * N, 2 * p, C, device=x.device) # k_x=v_x if self.index_mesh_grid is None: # further accelerate index_i, index_j = torch.meshgrid(torch.arange(0, p), torch.arange(0, p), indexing='ij') # 16*16 index_mesh_grid = torch.stack([index_i, index_j], 0).to( x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( B, 2, p, p) # B 2 p p. self.index_mesh_grid = index_mesh_grid[0:1] else: index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( B, 0) assert index_mesh_grid.shape == ( B, 2, p, p), 'check index_mesh_grid dimension' for i in range(group_size): q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( 0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C # TODO, how to batchify gather ops? plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1] # B 1 p p C plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] assert plane_yz.shape == plane_zx.shape == ( B, 1, p, p, C), 'check sub plane dimensions' pooling_plane_yz = torch.gather( plane_yz, dim=2, index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C pooling_plane_zx = torch.gather( plane_zx, dim=3, index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C context[B * i * N:B * (i + 1) * N] = torch.cat( [pooling_plane_yz, pooling_plane_zx], dim=2).reshape(B * N, 2 * p, C) # B 256 2 16 C => (B*256) 2*16 C # B, N, C = x.shape q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, C // self.num_heads) kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, self.num_heads, C // self.num_heads) k, v = unbind(kv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) x = self.proj(x) x = self.proj_drop(x) return x class xformer_Conv3D_Aware_CrossAttention_xygrid( xformer_Conv3D_Aware_CrossAttention): """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention """ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) def forward(self, x, attn_bias=None): B, group_size, N, C = x.shape # B 3 N C p = int(N**0.5) # patch size assert p**2 == N, 'check input dim, no [cls] needed here' assert group_size == 3, 'designed for triplane here' x = x.reshape(B, group_size, p, p, C) # expand patch token dim q_x = torch.empty(B * group_size * N, 1, C, device=x.device) context = torch.empty(B * group_size * N, 2 * p, C, device=x.device) # k_x=v_x if self.index_mesh_grid is None: # further accelerate index_u, index_v = torch.meshgrid( torch.arange(0, p), torch.arange(0, p), indexing='xy') # ! switch to 'xy' here to match uv coordinate index_mesh_grid = torch.stack([index_u, index_v], 0).to( x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( B, 2, p, p) # B 2 p p. self.index_mesh_grid = index_mesh_grid[0:1] else: index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( B, 0) assert index_mesh_grid.shape == ( B, 2, p, p), 'check index_mesh_grid dimension' for i in range(group_size): q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( 0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C # TODO, how to batchify gather ops? plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1] # B 1 p p C plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] assert plane_yz.shape == plane_zx.shape == ( B, 1, p, p, C), 'check sub plane dimensions' pooling_plane_yz = torch.gather( plane_yz, dim=2, index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand( -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C pooling_plane_zx = torch.gather( plane_zx, dim=3, index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand( -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C context[B * i * N:B * (i + 1) * N] = torch.cat( [pooling_plane_yz, pooling_plane_zx], dim=2).reshape(B * N, 2 * p, C) # B 256 2 16 C => (B*256) 2*16 C # B, N, C = x.shape q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, C // self.num_heads) kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, self.num_heads, C // self.num_heads) k, v = unbind(kv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) x = self.proj(x) x = self.proj_drop(x) return x class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( xformer_Conv3D_Aware_CrossAttention_xygrid): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0, proj_drop=0): super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) def forward(self, x, attn_bias=None): # ! split x: B N C into B 3 N C//3 B, N, C = x.shape x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1, 2) # B N C 3 -> B 3 N C x_out = super().forward(x, attn_bias) # B 3 N C x_out = x_out.permute(0, 2, 3, 1)# B 3 N C -> B N C 3 x_out = x_out.reshape(*x_out.shape[:2], -1) # B N C 3 -> B N C3 return x_out.contiguous() class self_cross_attn(nn.Module): def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.dino_attn = dino_attn self.cross_attn = cross_attn def forward(self, x_norm): y = self.dino_attn(x_norm) + x_norm return self.cross_attn(y) # will add x in the original code # class RodinRollOutConv(nn.Module): # """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention # Use Group Conv # """ # def __init__(self, in_chans, out_chans=None): # super().__init__() # # input: B 3C H W # if out_chans is None: # out_chans = in_chans # self.roll_out_convs = nn.Conv2d(in_chans, # out_chans, # kernel_size=3, # groups=3, # padding=1) # def forward(self, x): # return self.roll_out_convs(x) class RodinRollOutConv3D(nn.Module): """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention """ def __init__(self, in_chans, out_chans=None): super().__init__() if out_chans is None: out_chans = in_chans self.out_chans = out_chans // 3 self.roll_out_convs = nn.Conv2d(in_chans, self.out_chans, kernel_size=3, padding=1) def forward(self, x): # todo, reshape before input? B, C3, p, p = x.shape # B 3C H W C = C3 // 3 group_size = C3 // C assert group_size == 3 x = x.reshape(B, 3, C, p, p) roll_out_x = torch.empty(B, group_size * C, p, 3 * p, device=x.device) # B, 3C, H, 3W for i in range(group_size): plane_xy = x[:, i] # B C H W # TODO, simply do the average pooling? plane_yz_pooling = x[:, (i + 1) % group_size].mean( dim=-1, keepdim=True).repeat_interleave( p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim plane_zx_pooling = x[:, (i + 2) % group_size].mean( dim=-2, keepdim=True).repeat_interleave( p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim roll_out_x[..., i * p:(i + 1) * p] = torch.cat( [plane_xy, plane_yz_pooling, plane_zx_pooling], 1) # fill in the 3W dim x = self.roll_out_convs(roll_out_x) # B C H 3W x = x.reshape(B, self.out_chans, p, 3, p) x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p, p) # B 3C H W return x class RodinRollOutConv3D_GroupConv(nn.Module): """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention """ def __init__(self, in_chans, out_chans=None, kernel_size=3, stride=1, padding=1): super().__init__() if out_chans is None: out_chans = in_chans self.roll_out_convs = nn.Conv2d( in_chans * 3, out_chans, kernel_size=kernel_size, groups=3, # B 9C H W stride=stride, padding=padding) # @torch.autocast(device_type='cuda') def forward(self, x): # todo, reshape before input? B, C3, p, p = x.shape # B 3C H W C = C3 // 3 group_size = C3 // C assert group_size == 3 x = x.reshape(B, 3, C, p, p) roll_out_x = torch.empty(B, group_size * C * 3, p, p, device=x.device) # B, 3C, H, 3W for i in range(group_size): plane_xy = x[:, i] # B C H W # # TODO, simply do the average pooling? plane_yz_pooling = x[:, (i + 1) % group_size].mean( dim=-1, keepdim=True).repeat_interleave( p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim plane_zx_pooling = x[:, (i + 2) % group_size].mean( dim=-2, keepdim=True).repeat_interleave( p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( [plane_xy, plane_yz_pooling, plane_zx_pooling], 1) # fill in the 3W dim # ! directly cat, avoid intermediate vars # ? why OOM # roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( # [ # x[:, i], # x[:, (i + 1) % group_size].mean( # dim=-1, keepdim=True).repeat_interleave(p, dim=-1), # x[:, (i + 2) % group_size].mean( # dim=-2, keepdim=True).repeat_interleave( # p, dim=-2 # ) # B C H W -> B C 1 W -> B C H W, reduce z dim # ], # 1) # fill in the 3C dim x = self.roll_out_convs(roll_out_x) # B 3C H W return x class RodinRollOut_GroupConv_noConv3D(nn.Module): """only roll out and do Conv on individual planes """ def __init__(self, in_chans, out_chans=None, kernel_size=3, stride=1, padding=1): super().__init__() if out_chans is None: out_chans = in_chans self.roll_out_inplane_conv = nn.Conv2d( in_chans, out_chans, kernel_size=kernel_size, groups=3, # B 3C H W stride=stride, padding=padding) def forward(self, x): x = self.roll_out_inplane_conv(x) # B 3C H W return x # class RodinConv3D_SynthesisLayer_withact(nn.Module): # def __init__(self, in_chans, out_chans) -> None: # super().__init__() # self.act = nn.LeakyReLU(inplace=True) # self.conv = nn.Sequential( # RodinRollOutConv3D_GroupConv(in_chans, out_chans), # nn.LeakyReLU(inplace=True), # ) # if in_chans != out_chans: # self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. # else: # self.short_cut = None # def forward(self, feats): # if self.short_cut is not None: # res_feats = self.short_cut(feats) # else: # res_feats = feats # # return res_feats + self.conv(feats) # feats = res_feats + self.conv(feats) # return self.act(feats) # as in resnet, add an act before return class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module): def __init__(self, in_chans, out_chans) -> None: super().__init__() self.act = nn.LeakyReLU(inplace=True) self.conv = nn.Sequential( RodinRollOutConv3D_GroupConv(in_chans, out_chans), nn.LeakyReLU(inplace=True), ) self.out_chans = out_chans if in_chans != out_chans: # self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W in_chans // 3, # 144 / 3 = 48 out_chans // 3 * 4 * 4, # 32 * 16 bias=True) # decoder to pat # RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. else: self.short_cut = None def shortcut_unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None): """separate triplane version; x shape: B (3*257) 768 """ assert self.short_cut is not None # B, L, C = x.shape B, C3, h, w = x.shape assert h == w L = h * w x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1) # (B, 3, L // 3, C) x = self.short_cut(x) p = h * 4 x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans)) x = torch.einsum('ndhwpqc->ndchpwq', x) # nplanes, C order in the renderer.py x = x.reshape(shape=(B, 3 * self.out_chans, h * p, h * p)) return x def forward(self, feats): if self.short_cut is not None: res_feats = self.shortcut_unpatchify_triplane(feats) else: res_feats = feats # return res_feats + self.conv(feats) feats = res_feats + self.conv(feats) return self.act(feats) # as in resnet, add an act before return # class RodinConv3D_SynthesisLayer(nn.Module): # def __init__(self, in_chans, out_chans) -> None: # super().__init__() # self.act = nn.LeakyReLU(inplace=True) # self.conv = nn.Sequential( # RodinRollOutConv3D_GroupConv(in_chans, out_chans), # nn.LeakyReLU(inplace=True), # ) # if in_chans != out_chans: # self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. # else: # self.short_cut = None # def forward(self, feats): # if self.short_cut is not None: # res_feats = self.short_cut(feats) # else: # res_feats = feats # # return res_feats + self.conv(feats) # feats = res_feats + self.conv(feats) # # return self.act(feats) # as in resnet, add an act before return # return feats # ! old behaviour, no act # previous worked version class RodinConv3D_SynthesisLayer(nn.Module): def __init__(self, in_chans, out_chans) -> None: super().__init__() # x2 SR + 1x1 Conv Residual BLK # self.conv3D = RodinRollOutConv3D(in_chans, out_chans) self.act = nn.LeakyReLU(inplace=True) self.conv = nn.Sequential( RodinRollOutConv3D_GroupConv(in_chans, out_chans), nn.LeakyReLU(inplace=True), ) if in_chans != out_chans: self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) else: self.short_cut = None def forward(self, feats): feats_out = self.conv(feats) if self.short_cut is not None: # ! failed below feats_out = self.short_cut( feats ) + feats_out # ! only difference here, no act() compared with baseline # feats_out = self.act(self.short_cut(feats)) + feats_out # ! only difference here, no act() compared with baseline else: feats_out = feats_out + feats return feats_out class RodinRollOutConv3DSR2X(nn.Module): def __init__(self, in_chans, **kwargs) -> None: super().__init__() self.conv3D = RodinRollOutConv3D_GroupConv(in_chans) # self.conv3D = RodinRollOutConv3D(in_chans) self.act = nn.LeakyReLU(inplace=True) self.input_resolution = 224 def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 group_size = C3 // C assert group_size == 3 # p = int(N**0.5) # patch size # assert p**2 == N, 'check input dim, no [cls] needed here' assert group_size == 3, 'designed for triplane here' x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if x.shape[-1] != self.input_resolution: x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) x = x + self.conv3D(x) return x class RodinRollOutConv3DSR4X_lite(nn.Module): def __init__(self, in_chans, input_resolutiopn=256, **kwargs) -> None: super().__init__() self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans) self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans) self.act = nn.LeakyReLU(inplace=True) self.input_resolution = input_resolutiopn def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 group_size = C3 // C assert group_size == 3 # p = int(N**0.5) # patch size # assert p**2 == N, 'check input dim, no [cls] needed here' assert group_size == 3, 'designed for triplane here' x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if x.shape[-1] != self.input_resolution: x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) # ! still not convering, not bug here? # x = x + self.conv3D_0(x) # x = x + self.conv3D_1(x) x = x + self.act(self.conv3D_0(x)) x = x + self.act(self.conv3D_1(x)) # TODO: which is better, bilinear + conv or PixelUnshuffle? return x # class RodinConv3D2X_lite_mlp_as_residual(nn.Module): # """lite 4X version, with MLP unshuffle to change the dimention # """ # def __init__(self, in_chans, out_chans, input_resolution=256) -> None: # super().__init__() # self.act = nn.LeakyReLU(inplace=True) # self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) # self.act = nn.LeakyReLU(inplace=True) # self.input_resolution = input_resolution # self.out_chans = out_chans # if in_chans != out_chans: # ! only change the dimension # self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W # in_chans//3, # 144 / 3 = 48 # out_chans//3, # 32 * 16 # bias=True) # decoder to pat # else: # self.short_cut = None # def shortcut_unpatchify_triplane(self, x, p=None): # """separate triplane version; x shape: B (3*257) 768 # """ # assert self.short_cut is not None # # B, L, C = x.shape # B, C3, h, w = x.shape # assert h == w # L = h*w # x = x.reshape(B, C3//3, 3, L).permute(0,2,3,1) # (B, 3, L // 3, C_in) # x = self.short_cut(x) # B 3 L//3 C_out # x = x.permute(0,1,3,2) # B 3 C_out L//3 # x = x.reshape(shape=(B, self.out_chans, h, w)) # # directly resize to the target, no unpatchify here since no 3D ViT is included here # if w != self.input_resolution: # x = torch.nn.functional.interpolate(x, # 4X SR # size=(self.input_resolution, # self.input_resolution), # mode='bilinear', # align_corners=False, # antialias=True) # return x # def forward(self, x): # # x: B 3 112*112 C # B, C3, p, p = x.shape # after unpachify triplane # C = C3 // 3 # if self.short_cut is not None: # res_feats = self.shortcut_unpatchify_triplane(x) # else: # res_feats = x # """following forward code copied from lite4x version # """ # x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, # p) # B 3 C N -> B 3C h W # if x.shape[-1] != self.input_resolution: # x = torch.nn.functional.interpolate(x, # 4X SR # size=(self.input_resolution, # self.input_resolution), # mode='bilinear', # align_corners=False, # antialias=True) # x = res_feats + self.act(self.conv3D_0(x)) # x = x + self.act(self.conv3D_1(x)) # return x class RodinConv3D4X_lite_mlp_as_residual(nn.Module): """lite 4X version, with MLP unshuffle to change the dimention """ def __init__(self, in_chans, out_chans, input_resolution=256, interp_mode='bilinear', bcg_triplane=False) -> None: super().__init__() self.interp_mode = interp_mode self.act = nn.LeakyReLU(inplace=True) self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) self.bcg_triplane = bcg_triplane if bcg_triplane: self.conv3D_1_bg = RodinRollOutConv3D_GroupConv( out_chans, out_chans) self.act = nn.LeakyReLU(inplace=True) self.input_resolution = input_resolution self.out_chans = out_chans if in_chans != out_chans: # ! only change the dimension self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W in_chans // 3, # 144 / 3 = 48 out_chans // 3, # 32 * 16 bias=True) # decoder to pat else: self.short_cut = None def shortcut_unpatchify_triplane(self, x, p=None): """separate triplane version; x shape: B (3*257) 768 """ assert self.short_cut is not None B, C3, h, w = x.shape assert h == w L = h * w x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1) # (B, 3, L // 3, C_in) x = self.short_cut(x) # B 3 L//3 C_out x = x.permute(0, 1, 3, 2) # B 3 C_out L//3 x = x.reshape(shape=(B, self.out_chans, h, w)) # directly resize to the target, no unpatchify here since no 3D ViT is included here if w != self.input_resolution: x = torch.nn.functional.interpolate( x, # 4X SR size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) return x def interpolate(self, feats): if self.interp_mode == 'bilinear': return torch.nn.functional.interpolate( feats, # 4X SR size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) else: return torch.nn.functional.interpolate( feats, # 4X SR size=(self.input_resolution, self.input_resolution), mode='nearest', ) def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 if self.short_cut is not None: res_feats = self.shortcut_unpatchify_triplane(x) else: res_feats = x if res_feats.shape[-1] != self.input_resolution: res_feats = self.interpolate(res_feats) """following forward code copied from lite4x version """ x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if x.shape[-1] != self.input_resolution: x = self.interpolate(x) x0 = res_feats + self.act(self.conv3D_0(x)) # the base feature x = x0 + self.act(self.conv3D_1(x0)) if self.bcg_triplane: x_bcg = x0 + self.act(self.conv3D_1_bg(x0)) return torch.cat([x, x_bcg], 1) else: return x class RodinConv3D4X_lite_mlp_as_residual_litev2( RodinConv3D4X_lite_mlp_as_residual): def __init__(self, in_chans, out_chans, num_feat=128, input_resolution=256, interp_mode='bilinear', bcg_triplane=False) -> None: super().__init__(in_chans, out_chans, input_resolution, interp_mode, bcg_triplane) self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans) self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D( in_chans, num_feat * 3) self.conv3D_1 = RodinRollOut_GroupConv_noConv3D( num_feat * 3, num_feat * 3) self.conv_last = RodinRollOut_GroupConv_noConv3D( num_feat * 3, out_chans) self.short_cut = None def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 # if self.short_cut is not None: # res_feats = self.shortcut_unpatchify_triplane(x) # else: # res_feats = x # if res_feats.shape[-1] != self.input_resolution: # res_feats = self.interpolate(res_feats) """following forward code copied from lite4x version """ x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W x = x + self.conv3D_0(x) # the base feature x = self.act(self.conv_before_upsample(x)) # if x.shape[-1] != self.input_resolution: x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x)))) return x class RodinConv3D4X_lite_mlp_as_residual_lite( RodinConv3D4X_lite_mlp_as_residual): def __init__(self, in_chans, out_chans, input_resolution=256, interp_mode='bilinear') -> None: super().__init__(in_chans, out_chans, input_resolution, interp_mode) """replace the first Rodin Conv 3D with ordinary rollout conv to save memory """ self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans) class SR3D(nn.Module): # https://github.com/SeanChenxy/Mimic3D/blob/77d313656df3cd5536d2c4c5766db3a56208eea6/training/networks_stylegan2.py#L629 # roll-out and apply two deconv/pixelUnshuffle layer def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module): def __init__(self, in_chans, num_feat, out_chans, input_resolution=256) -> None: super().__init__() assert in_chans == 4 * out_chans assert num_feat == 2 * out_chans self.input_resolution = input_resolution # refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750 self.upscale = 4 self.conv_after_body = RodinRollOutConv3D_GroupConv( in_chans, in_chans, 3, 1, 1) self.conv_before_upsample = nn.Sequential( RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1) if self.upscale == 4: self.conv_up2 = RodinRollOutConv3D_GroupConv( num_feat, num_feat, 3, 1, 1) self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1) self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 """following forward code copied from lite4x version """ x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W # ? nearest or bilinear x = self.conv_after_body(x) + x x = self.conv_before_upsample(x) x = self.lrelu( self.conv_up1( torch.nn.functional.interpolate( x, scale_factor=2, mode='nearest', # align_corners=False, # antialias=True ))) if self.upscale == 4: x = self.lrelu( self.conv_up2( torch.nn.functional.interpolate( x, scale_factor=2, mode='nearest', # align_corners=False, # antialias=True ))) x = self.conv_last(self.lrelu(self.conv_hr(x))) assert x.shape[-1] == self.input_resolution return x class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module): def __init__(self, in_chans, num_feat, out_chans, input_resolution=256) -> None: super().__init__() assert in_chans == 4 * out_chans assert num_feat == 2 * out_chans self.input_resolution = input_resolution # refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750 self.upscale = 4 self.conv_after_body = RodinRollOutConv3D_GroupConv( in_chans, in_chans, 3, 1, 1) self.conv_before_upsample = nn.Sequential( RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1) if self.upscale == 4: self.conv_up2 = RodinRollOutConv3D_GroupConv( num_feat, num_feat, 3, 1, 1) self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1) self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 """following forward code copied from lite4x version """ x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W # ? nearest or bilinear x = self.conv_after_body(x) + x x = self.conv_before_upsample(x) x = self.lrelu( self.conv_up1( torch.nn.functional.interpolate( x, scale_factor=2, mode='nearest', # align_corners=False, # antialias=True ))) if self.upscale == 4: x = self.lrelu( self.conv_up2( torch.nn.functional.interpolate( x, scale_factor=2, mode='nearest', # align_corners=False, # antialias=True ))) x = self.conv_last(self.lrelu(self.conv_hr(x) + x)) assert x.shape[-1] == self.input_resolution return x class RodinRollOutConv3DSR_FlexibleChannels(nn.Module): def __init__(self, in_chans, num_out_ch=96, input_resolution=256, **kwargs) -> None: super().__init__() self.block0 = RodinConv3D_SynthesisLayer(in_chans, num_out_ch) # in_chans=48 self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch) self.input_resolution = input_resolution # 64 -> 256 SR def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 # group_size = C3 // C x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if x.shape[-1] != self.input_resolution: x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) x = self.block0(x) x = self.block1(x) return x # previous worked version class RodinRollOutConv3DSR4X(nn.Module): # follow PixelUnshuffleUpsample def __init__(self, in_chans, **kwargs) -> None: super().__init__() # self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96 * 2) # TODO, match the old behaviour now. # self.block1 = RodinConv3D_SynthesisLayer(96 * 2, 96) self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96) self.block1 = RodinConv3D_SynthesisLayer( 96, 96) # baseline choice, validate with no LPIPS loss here self.input_resolution = 64 # 64 -> 256 def forward(self, x): # x: B 3 112*112 C B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 # group_size = C3 // C x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if x.shape[-1] != self.input_resolution: x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) x = self.block0(x) x = self.block1(x) return x class Upsample3D(nn.Module): """Upsample module. Args: scale (int): Scale factor. Supported scales: 2^n and 3. num_feat (int): Channel number of intermediate features. """ def __init__(self, scale, num_feat): super().__init__() m_convs = [] m_pixelshuffle = [] assert (scale & (scale - 1)) == 0, 'scale = 2^n' self.scale = scale for _ in range(int(math.log(scale, 2))): m_convs.append( RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1)) m_pixelshuffle.append(nn.PixelShuffle(2)) self.m_convs = nn.ModuleList(m_convs) self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle) # @torch.autocast(device_type='cuda') def forward(self, x): for scale_idx in range(int(math.log(self.scale, 2))): x = self.m_convs[scale_idx](x) # B 3C H W # x = # B, C3, H, W = x.shape x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:]) x = self.m_pixelshuffle[scale_idx](x) x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:]) return x class RodinConv3DPixelUnshuffleUpsample(nn.Module): def __init__(self, output_dim, num_feat=32 * 6, num_out_ch=32 * 3, sr_ratio=4, *args, **kwargs) -> None: super().__init__() self.conv_after_body = RodinRollOutConv3D_GroupConv( output_dim, output_dim, 3, 1, 1) self.conv_before_upsample = nn.Sequential( RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) self.upsample = Upsample3D(sr_ratio, num_feat) # 4 time SR self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3, 1, 1) # @torch.autocast(device_type='cuda') def forward(self, x, input_skip_connection=True, *args, **kwargs): # x = self.conv_first(x) if input_skip_connection: x = self.conv_after_body(x) + x else: x = self.conv_after_body(x) x = self.conv_before_upsample(x) x = self.upsample(x) x = self.conv_last(x) return x class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module): def __init__( self, output_dim, num_out_ch=32 * 3, sr_ratio=4, input_resolution=256, ) -> None: super().__init__() self.input_resolution = input_resolution # self.conv_first = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, # 3, 1, 1) self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1) def forward(self, x, bilinear_upsample=True): B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 group_size = C3 // C assert group_size == 3, 'designed for triplane here' x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if bilinear_upsample and x.shape[-1] != self.input_resolution: x_bilinear_upsample = torch.nn.functional.interpolate( x, size=(self.input_resolution, self.input_resolution), mode='bilinear', align_corners=False, antialias=True) x = self.upsample(x) + x_bilinear_upsample else: # x_bilinear_upsample = x x = self.upsample(x) x = self.conv_last(x) return x class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module): """removed nearest neighbour residual conenctions, add a conv layer residual conenction """ def __init__( self, output_dim, num_out_ch=32 * 3, sr_ratio=4, input_resolution=256, ) -> None: super().__init__() self.input_resolution = input_resolution self.conv_after_body = RodinRollOutConv3D_GroupConv( output_dim, num_out_ch, 3, 1, 1) self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1) def forward(self, x, input_skip_connection=True): B, C3, p, p = x.shape # after unpachify triplane C = C3 // 3 group_size = C3 // C assert group_size == 3, 'designed for triplane here' x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p) # B 3 C N -> B 3C h W if input_skip_connection: x = self.conv_after_body(x) + x else: x = self.conv_after_body(x) x = self.upsample(x) x = self.conv_last(x) return x class CLSCrossAttentionBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False): super().__init__() self.norm1 = norm_layer(dim) self.attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.has_mlp = has_mlp if has_mlp: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) if self.has_mlp: x = x + self.drop_path(self.mlp(self.norm2(x))) return x class Conv3DCrossAttentionBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False): super().__init__() self.norm1 = norm_layer(dim) self.attn = Conv3D_Aware_CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.has_mlp = has_mlp if has_mlp: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) if self.has_mlp: x = x + self.drop_path(self.mlp(self.norm2(x))) return x class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock): def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop=0, attn_drop=0, drop_path=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False): super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer, has_mlp) # self.attn = xformer_Conv3D_Aware_CrossAttention(dim, self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) class Conv3DCrossAttentionBlockXformerMHANested( Conv3DCrossAttentionBlockXformerMHA): def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False): super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer, has_mlp) """for in-place replaing the internal attn in Dino ViT. """ def forward(self, x): Bx3, N, C = x.shape B, group_size = Bx3 // 3, 3 x = x.reshape(B, group_size, N, C) # in plane vit x = super().forward(x) return x.reshape(B * group_size, N, C) # to match the original attn size class Conv3DCrossAttentionBlockXformerMHANested_withinC( Conv3DCrossAttentionBlockXformerMHANested): def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop=0, attn_drop=0, drop_path=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False): super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer, has_mlp) self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) def forward(self, x): # basic TX attention forward function x = x + self.drop_path(self.attn(self.norm1(x))) if self.has_mlp: x = x + self.drop_path(self.mlp(self.norm2(x))) return x class TriplaneFusionBlock(nn.Module): """4 ViT blocks + 1 CrossAttentionBlock """ def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, cross_attention_blk=CLSCrossAttentionBlock, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.num_branches = 3 # triplane self.vit_blks = vit_blks if use_fusion_blk: self.fusion = nn.ModuleList() # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check for d in range(self.num_branches): self.fusion.append( cross_attention_blk( dim=dim, num_heads=nh, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, # drop=drop, drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path_rate, norm_layer=norm_layer, # type: ignore has_mlp=False)) else: self.fusion = None def forward(self, x): # modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132 """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' x = x.view(B * group_size, N, C) for blk in self.vit_blks: x = blk(x) # B 3 N C if self.fusion is None: return x.view(B, group_size, N, C) # outs_b = x.view(B, group_size, N, # C).chunk(chunks=3, # dim=1) # 3 * [B, 1, N//3, C] Tensors, for fusion outs_b = x.chunk(chunks=3, dim=0) # 3 * [B, N//3, C] Tensors, for fusion # only take the cls token out proj_cls_token = [x[:, 0:1] for x in outs_b] # cross attention outs = [] for i in range(self.num_branches): tmp = torch.cat( (proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) tmp = self.fusion[i](tmp) # reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) reverted_proj_cls_token = tmp[:, 0:1, ...] tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) outs.append(tmp) # outs = ? needs to merge back? outs = torch.stack(outs, 1) # B 3 N C return outs class TriplaneFusionBlockv2(nn.Module): """4 ViT blocks + 1 CrossAttentionBlock """ def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlock, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.num_branches = 3 # triplane self.vit_blks = vit_blks if use_fusion_blk: # self.fusion = nn.ModuleList() # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check # for d in range(self.num_branches): self.fusion = fusion_ca_blk( # one fusion is enough dim=dim, num_heads=nh, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, # drop=drop, drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path_rate, norm_layer=norm_layer, # type: ignore has_mlp=False) else: self.fusion = None def forward(self, x): # modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132 """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' x = x.reshape(B * group_size, N, C) for blk in self.vit_blks: x = blk(x) # B 3 N C if self.fusion is None: return x.reshape(B, group_size, N, C) x = x.reshape(B, group_size, N, C) # .chunk(chunks=3, # dim=1) # 3 * [B, N//3, C] Tensors, for fusion return self.fusion(x) class TriplaneFusionBlockv3(TriplaneFusionBlockv2): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, *args, **kwargs) class TriplaneFusionBlockv4(TriplaneFusionBlockv3): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, *args, **kwargs) """OOM? directly replace the atten here """ assert len(vit_blks) == 2 # del self.vit_blks[1].attn del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1 def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor: return tx_blk.ls2( tx_blk.mlp(tx_blk.norm2(x)) ) # https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/layers/block.py#L86C1-L87C53 def forward(self, x): """x: B 3 N C, where N = H*W tokens """ assert self.fusion is not None B, group_size, N, C = x.shape # has [cls] token in N x = x.reshape(B * group_size, N, C) # in plane vit # in plane self attention x = self.vit_blks[0](x) # 3D cross attention blk + ffn x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape( B * group_size, N, C) x = x + self.ffn_residual_func(self.vit_blks[1], x) return x.reshape(B, group_size, N, C) class TriplaneFusionBlockv4_nested(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, *args, **kwargs) -> None: super().__init__() self.num_branches = 3 # triplane self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # ! replace vit_blks[1] attn layer with 3D aware attention del self.vit_blks[ 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1 # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check self.vit_blks[1].attn = fusion_ca_blk( # one fusion is enough dim=dim, num_heads=nh, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, # drop=drop, drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path_rate, norm_layer=norm_layer, # type: ignore has_mlp=False) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' x = x.reshape(B * group_size, N, C) for blk in self.vit_blks: x = blk(x) # B 3 N C # TODO, avoid the reshape overhead? return x.reshape(B, group_size, N, C) class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: super().__init__() self.num_branches = 3 # triplane self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check attn_3d = fusion_ca_blk( # one fusion is enough dim=dim, num_heads=nh, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, # drop=drop, drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path_rate, norm_layer=norm_layer, # type: ignore has_mlp=False) # ! initialize 3dattn from dino attn if init_from_dino: merged_qkv_linear = self.vit_blks[1].attn.qkv attn_3d.attn.proj.load_state_dict( self.vit_blks[1].attn.proj.state_dict()) # Initialize the Q, K, and V linear layers using the weights of the merged QKV linear layer attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[: dim, :] attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[ dim:, :] # Optionally, you can initialize the biases as well (if your QKV linear layer has biases) if qkv_bias: attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim] attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:] del self.vit_blks[1].attn # ! assign self.vit_blks[1].attn = attn_3d def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' x = x.reshape(B * group_size, N, C) for blk in self.vit_blks: x = blk(x) # B 3 N C # TODO, avoid the reshape overhead? return x.reshape(B, group_size, N, C) class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=None, *args, **kwargs) -> None: super().__init__() self.num_branches = 3 # triplane self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) del self.vit_blks[1].attn # ! assign self.vit_blks[1].attn = attn_3d def forward(self, x): """x: B N C, where N = H*W tokens. Just raw ViT forward pass """ # ! move the below to the front of the first call B, N, C = x.shape # has [cls] token in N for blk in self.vit_blks: x = blk(x) # B N C return x class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=None, *args, **kwargs) -> None: super().__init__() self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim qkv_bias = True attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check if False: # abla for blk in self.vit_blks: attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) blk.attn = self_cross_attn(blk.attn, attn_3d) def forward(self, x): """x: B N C, where N = H*W tokens. Just raw ViT forward pass """ # ! move the below to the front of the first call B, N, C = x.shape # has [cls] token in N for blk in self.vit_blks: x = blk(x) # B N C return x class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): # on roll out + B 3L C def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # ! move the below to the front of the first call # B, N, C = x.shape # has [cls] token in N B, group_size, N, C = x.shape # has [cls] token in N x = x.reshape(B, group_size*N, C) for blk in self.vit_blks: x = blk(x) # B N C x = x.reshape(B, group_size, N, C) # outer loop tradition return x class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): # roll out + B 3L C def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # ! move the below to the front of the first call # B, N, C = x.shape # has [cls] token in N B, group_size, N, C = x.shape # has [cls] token in N x = x.reshape(B*group_size, N, C) x = self.vit_blks[0](x) x = x.reshape(B,group_size*N, C) x = self.vit_blks[1](x) x = x.reshape(B, group_size, N, C) # outer loop tradition return x class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino): # no roll out + 3D Attention def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ B, group_size, N, C = x.shape # has [cls] token in N x = x.reshape(B, group_size*N, C) x = self.vit_blks[0](x) # B 3 L C # ! move the below to the front of the first call x = x.reshape(B, group_size, N, C).reshape(B*group_size, N, C) x = self.vit_blks[1](x) # has 3D attention return x.reshape(B, group_size, N, C) return x class TriplaneFusionBlockv5_ldm_addCA(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, *args, **kwargs) -> None: super().__init__() self.num_branches = 3 # triplane self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # ! rather than replacing, add a 3D attention block after. # del self.vit_blks[ # 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1 self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1) # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' flatten_token = lambda x: x.reshape(B * group_size, N, C) unflatten_token = lambda x: x.reshape(B, group_size, N, C) x = flatten_token(x) x = self.vit_blks[0](x) x = unflatten_token(x) x = self.attn_3d(self.norm_for_atten_3d(x)) + x x = flatten_token(x) x = self.vit_blks[1](x) return unflatten_token(x) class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D( TriplaneFusionBlockv5_ldm_addCA): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, *args, **kwargs) -> None: super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, *args, **kwargs) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' flatten_token = lambda x: x.reshape(B * group_size, N, C) unflatten_token = lambda x: x.reshape(B, group_size, N, C) x = flatten_token(x) x = self.vit_blks[0](x) x = unflatten_token(x) x = self.attn_3d(self.norm_for_atten_3d(x)) + x x = flatten_token(x) x = self.vit_blks[1](x) return unflatten_token(x) class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module): def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, *args, **kwargs) -> None: super().__init__() self.num_branches = 3 # triplane self.vit_blks = vit_blks assert use_fusion_blk assert len(vit_blks) == 2 # ! rather than replacing, add a 3D attention block after. # del self.vit_blks[ # 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1 self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1) self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1) # copied vit settings from https://github.dev/facebookresearch/dinov2 nh = num_heads dim = embed_dim mlp_ratio = 4 # defined for all dino2 model qkv_bias = True norm_layer = partial(nn.LayerNorm, eps=1e-6) drop_path_rate = 0.3 # default setting attn_drop = proj_drop = 0.0 qk_scale = None # TODO, double check self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) self.attn_3d_1 = deepcopy(self.attn_3d_0) def forward(self, x): """x: B 3 N C, where N = H*W tokens """ # self attention, by merging the triplane channel into B for parallel computation # ! move the below to the front of the first call B, group_size, N, C = x.shape # has [cls] token in N assert group_size == 3, 'triplane' flatten_token = lambda x: x.reshape(B * group_size, N, C) unflatten_token = lambda x: x.reshape(B, group_size, N, C) x = flatten_token(x) x = self.vit_blks[0](x) x = unflatten_token(x) x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x x = flatten_token(x) x = self.vit_blks[1](x) x = unflatten_token(x) x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x return unflatten_token(x) def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0], ) + (1, ) * ( x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand( shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) # self.attn = Attention(dim, self.attn = MemEffAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x, return_attention=False): y, attn = self.attn(self.norm1(x)) if return_attention: return attn x = x + self.drop_path(y) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): B, C, H, W = x.shape x = self.proj(x).flatten(2).transpose(1, 2) # B, C, L -> B, L, C return x class VisionTransformer(nn.Module): """ Vision Transformer """ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer='nn.LayerNorm', patch_embedding=True, cls_token=True, pixel_unshuffle=False, **kwargs): super().__init__() self.num_features = self.embed_dim = embed_dim self.patch_size = patch_size # if norm_layer == 'nn.LayerNorm': norm_layer = partial(nn.LayerNorm, eps=1e-6) if patch_embedding: self.patch_embed = PatchEmbed(img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.img_size = self.patch_embed.img_size else: self.patch_embed = None self.img_size = img_size[0] num_patches = (img_size[0] // patch_size) * (img_size[0] // patch_size) if cls_token: self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, embed_dim)) else: self.cls_token = None self.pos_embed = nn.Parameter( torch.zeros(1, num_patches, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) for i in range(depth) ]) self.norm = norm_layer(embed_dim) # Classifier head self.head = nn.Linear( embed_dim, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) if cls_token: trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) # if pixel_unshuffle: # self.decoder_pred = nn.Linear(embed_dim, # patch_size**2 * out_chans, # bias=True) # decoder to patch def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def interpolate_pos_encoding(self, x, w, h): npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed patch_pos_embed = self.pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), mode='bicubic', ) assert int(w0) == patch_pos_embed.shape[-2] and int( h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(2, -1, dim) if self.cls_token is not None: class_pos_embed = self.pos_embed[:, 0] return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) return patch_pos_embed def prepare_tokens(self, x): B, nc, w, h = x.shape x = self.patch_embed(x) # patch linear embedding # add the [CLS] token to the embed patch tokens cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) # add positional encoding to each token x = x + self.interpolate_pos_encoding(x, w, h) return self.pos_drop(x) def forward(self, x): x = self.prepare_tokens(x) for blk in self.blocks: x = blk(x) x = self.norm(x) return x[:, 1:] # return spatial feature maps, not the [CLS] token # return x[:, 0] def get_last_selfattention(self, x): x = self.prepare_tokens(x) for i, blk in enumerate(self.blocks): if i < len(self.blocks) - 1: x = blk(x) else: # return attention of the last block return blk(x, return_attention=True) def get_intermediate_layers(self, x, n=1): x = self.prepare_tokens(x) # we return the output tokens from the `n` last blocks output = [] for i, blk in enumerate(self.blocks): x = blk(x) if len(self.blocks) - i <= n: output.append(self.norm(x)) return output def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer(patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model def vit_small(patch_size=16, **kwargs): model = VisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), # type: ignore **kwargs) return model def vit_base(patch_size=16, **kwargs): model = VisionTransformer(patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model vits = vit_small vitb = vit_base