from functools import partial
from typing import Tuple, List, Optional

import torch
from torch import Tensor, nn
from mmengine.model import BaseModule, normal_init
from mmdet.registry import MODELS
from mmdet.models.layers import PatchEmbed

from ext.meta.sam_meta import checkpoint_dict
from ext.sam.common import LayerNorm2d
from ext.sam.image_encoder import Block
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class MultiLayerTransformerNeck(BaseModule):
    """Fuse multi-scale backbone features into a single SAM-style feature map.

    Each selected backbone level is projected to ``embed_channels`` at stride
    ``STRIDE``, tagged with a learnable level embedding, summed, and refined by
    a small stack of windowed/global attention blocks before a convolutional
    neck reduces it to ``out_channels``.
    """

    STRIDE = 16

    def __init__(
            self,
            input_size: Tuple[int, int],
            in_channels: List[int],
            embed_channels: int,
            out_channels: int,
            layer_ids: Tuple[int] = (0, 1, 2, 3),
            strides: Tuple[int] = (4, 8, 16, 32),
            embedding_path: Optional[str] = None,
            fix: bool = False,
            init_cfg=None
    ) -> None:
        # init_cfg is handled manually at the end of __init__,
        # so it is not forwarded to BaseModule here.
        super().__init__(init_cfg=None)

        self.transformer_size = (input_size[0] // self.STRIDE, input_size[1] // self.STRIDE)
        self.layer_ids = layer_ids

        # One patch-embedding module per selected level, bringing every input
        # scale to the common transformer stride (upsampling coarser levels,
        # downsampling finer ones).
        self.patch_embeds = nn.ModuleList()
        for idx, in_ch in enumerate(in_channels):
            if idx in layer_ids:
                if strides[idx] > self.STRIDE:
                    patch_embed = PatchEmbed(
                        conv_type=nn.ConvTranspose2d,
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=strides[idx] // self.STRIDE,
                        stride=strides[idx] // self.STRIDE,
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                else:
                    patch_embed = PatchEmbed(
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=self.STRIDE // strides[idx],
                        stride=self.STRIDE // strides[idx],
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                self.patch_embeds.append(patch_embed)
            else:
                self.patch_embeds.append(nn.Identity())

        if embedding_path is not None:
            # Initialize the positional embedding from a SAM image encoder
            # checkpoint, e.g. embedding_path='sam_vit_b'.
            assert embedding_path.startswith('sam_')
            embedding_ckpt = embedding_path.split('_', maxsplit=1)[1]
            path = checkpoint_dict[embedding_ckpt]
            state_dict = load_checkpoint_with_prefix(path, prefix='image_encoder')
            pos_embed = state_dict['pos_embed']
        else:
            # Placeholder; the real values are expected to come from a checkpoint.
            pos_embed = torch.zeros(1, input_size[0] // self.STRIDE, input_size[1] // self.STRIDE, embed_channels)
        self.register_buffer('pos_embed', pos_embed)

        self.level_encoding = nn.Embedding(len(layer_ids), embed_channels)

        depth = 5
        global_attn_indexes = [4]
        window_size = 14

        # Windowed attention blocks, with the last block using global attention.
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_channels,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                use_rel_pos=True,
                rel_pos_zero_init=True,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=self.transformer_size,
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )

        self.fix = fix
        if self.fix:
            # Freeze the whole neck: keep it in eval mode and disable gradients.
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

        if init_cfg is not None:
            assert init_cfg['type'] == 'Pretrained'
            checkpoint_path = init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            self._is_init = True

    def init_weights(self):
        normal_init(self.level_encoding, mean=0, std=1)

    def train(self, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A frozen neck stays in eval mode regardless of the requested mode.
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward(self, inputs: Tuple[Tensor]) -> Tensor:
        # Project every selected level to stride-16 tokens and tag each level
        # with its learnable level embedding.
        input_embeddings = []
        level_cnt = 0
        for idx, feat in enumerate(inputs):
            if idx not in self.layer_ids:
                continue
            feat, size = self.patch_embeds[idx](feat)
            feat = feat.unflatten(1, size)
            feat = feat + self.level_encoding.weight[level_cnt]
            input_embeddings.append(feat)
            level_cnt += 1

        # Fuse the levels by summation and add the positional embedding.
        feat = sum(input_embeddings)
        feat = feat + self.pos_embed

        for block in self.blocks:
            feat = block(feat)

        # (B, H, W, C) -> (B, C, H, W) for the convolutional neck.
        feat = feat.permute(0, 3, 1, 2).contiguous()
        feat = self.neck(feat)
        return feat
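

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only): the 1024x1024 input size
    # and the 256/512/1024/2048 channel counts below are assumptions for a
    # typical 4-scale backbone paired with SAM ViT-B sized (768-dim) tokens;
    # they are not fixed by this module.
    neck = MultiLayerTransformerNeck(
        input_size=(1024, 1024),
        in_channels=[256, 512, 1024, 2048],
        embed_channels=768,
        out_channels=256,
    )
    # One dummy feature map per backbone level at strides 4/8/16/32.
    feats = tuple(
        torch.rand(1, ch, 1024 // s, 1024 // s)
        for ch, s in zip([256, 512, 1024, 2048], (4, 8, 16, 32))
    )
    out = neck(feats)
    print(out.shape)  # expected: torch.Size([1, 256, 64, 64])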