# # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All Rights Reserved. # import copy from functools import partial from typing import List, Tuple, Optional, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropPath, trunc_normal_ from timm.models import register_model from mobileclip.modules.common.mobileone import MobileOneBlock from mobileclip.modules.image.replknet import ReparamLargeKernelConv def _cfg(url="", **kwargs): return { "url": url, "num_classes": 1000, "input_size": (3, 256, 256), "pool_size": None, "crop_pct": 0.95, "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "classifier": "head", **kwargs, } default_cfgs = { "fastvit_t": _cfg(crop_pct=0.9), "fastvit_s": _cfg(crop_pct=0.9), "fastvit_m": _cfg(crop_pct=0.95), } def convolutional_stem( in_channels: int, out_channels: int, inference_mode: bool = False ) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. Args: in_channels: Number of input channels. out_channels: Number of output channels. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` Returns: nn.Sequential object with stem elements. """ return nn.Sequential( MobileOneBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=out_channels, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), ) class MHSA(nn.Module): """Multi-headed Self Attention module. Source modified from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ def __init__( self, dim: int, head_dim: int = 32, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. Args: dim: Number of embedding dimensions. head_dim: Number of hidden dimensions per head. Default: ``32`` qkv_bias: Use bias or not. Default: ``False`` attn_drop: Dropout rate for attention tensor. proj_drop: Dropout rate for projection tensor. """ super().__init__() assert dim % head_dim == 0, "dim should be divisible by head_dim" self.head_dim = head_dim self.num_heads = dim // head_dim self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: shape = x.shape B, C, H, W = shape N = H * W if len(shape) == 4: x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C) qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, self.head_dim) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) # trick here to make q@k.t more stable attn = (q * self.scale) @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) if len(shape) == 4: x = x.transpose(-2, -1).reshape(B, C, H, W) return x class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( self, patch_size: int, stride: int, in_channels: int, embed_dim: int, inference_mode: bool = False, use_se: bool = False, ) -> None: """Build patch embedding layer. Args: patch_size: Patch size for embedding computation. stride: Stride for convolutional embedding layer. in_channels: Number of channels of input tensor. embed_dim: Number of embedding dimensions. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` use_se: If ``True`` SE block will be used. """ super().__init__() block = list() block.append( ReparamLargeKernelConv( in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=stride, groups=in_channels, small_kernel=3, inference_mode=inference_mode, use_se=use_se, ) ) block.append( MobileOneBlock( in_channels=embed_dim, out_channels=embed_dim, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ) ) self.proj = nn.Sequential(*block) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x class RepMixer(nn.Module): """Reparameterizable token mixer. For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_ """ def __init__( self, dim, kernel_size=3, use_layer_scale=True, layer_scale_init_value=1e-5, inference_mode: bool = False, ): """Build RepMixer Module. Args: dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. kernel_size: Kernel size for spatial mixing. Default: 3 use_layer_scale: If True, learnable layer scale is used. Default: ``True`` layer_scale_init_value: Initial value for layer scale. Default: 1e-5 inference_mode: If True, instantiates model in inference mode. Default: ``False`` """ super().__init__() self.dim = dim self.kernel_size = kernel_size self.inference_mode = inference_mode if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=self.dim, out_channels=self.dim, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, groups=self.dim, bias=True, ) else: self.norm = MobileOneBlock( dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, use_act=False, use_scale_branch=False, num_conv_branches=0, ) self.mixer = MobileOneBlock( dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, use_act=False, ) self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "reparam_conv"): x = self.reparam_conv(x) return x else: if self.use_layer_scale: x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) else: x = x + self.mixer(x) - self.norm(x) return x def reparameterize(self) -> None: """Reparameterize mixer and norm into a single convolutional layer for efficient inference. """ if self.inference_mode: return self.mixer.reparameterize() self.norm.reparameterize() if self.use_layer_scale: w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = torch.squeeze(self.layer_scale) * ( self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias ) else: w = ( self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias self.reparam_conv = nn.Conv2d( in_channels=self.dim, out_channels=self.dim, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, groups=self.dim, bias=True, ) self.reparam_conv.weight.data = w self.reparam_conv.bias.data = b for para in self.parameters(): para.detach_() self.__delattr__("mixer") self.__delattr__("norm") if self.use_layer_scale: self.__delattr__("layer_scale") class ConvFFN(nn.Module): """Convolutional FFN Module.""" def __init__( self, in_channels: int, hidden_channels: Optional[int] = None, out_channels: Optional[int] = None, act_layer: nn.Module = nn.GELU, drop: float = 0.0, ) -> None: """Build convolutional FFN module. Args: in_channels: Number of input channels. hidden_channels: Number of channels after expansion. Default: None out_channels: Number of output channels. Default: None act_layer: Activation layer. Default: ``GELU`` drop: Dropout rate. Default: ``0.0``. """ super().__init__() out_channels = out_channels or in_channels hidden_channels = hidden_channels or in_channels self.conv = nn.Sequential() self.conv.add_module( "conv", nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=7, padding=3, groups=in_channels, bias=False, ), ) self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class RepCPE(nn.Module): """Implementation of conditional positional encoding. For more details refer to paper: `Conditional Positional Encodings for Vision Transformers `_ In our implementation, we can reparameterize this module to eliminate a skip connection. """ def __init__( self, in_channels: int, embed_dim: int = 768, spatial_shape: Union[int, Tuple[int, int]] = (7, 7), inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding Args: in_channels: Number of input channels. embed_dim: Number of embedding dimensions. Default: 768 spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super(RepCPE, self).__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( f'"spatial_shape" must by a sequence or int, ' f"get {type(spatial_shape)} instead." ) assert len(spatial_shape) == 2, ( f'Length of "spatial_shape" should be 2, ' f"got {len(spatial_shape)} instead." ) self.spatial_shape = spatial_shape self.embed_dim = embed_dim self.in_channels = in_channels self.groups = embed_dim if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) else: self.pe = nn.Conv2d( in_channels, embed_dim, spatial_shape, 1, int(spatial_shape[0] // 2), bias=True, groups=embed_dim, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "reparam_conv"): x = self.reparam_conv(x) return x else: x = self.pe(x) + x return x def reparameterize(self) -> None: # Build equivalent Id tensor input_dim = self.in_channels // self.groups kernel_value = torch.zeros( ( self.in_channels, input_dim, self.spatial_shape[0], self.spatial_shape[1], ), dtype=self.pe.weight.dtype, device=self.pe.weight.device, ) for i in range(self.in_channels): kernel_value[ i, i % input_dim, self.spatial_shape[0] // 2, self.spatial_shape[1] // 2, ] = 1 id_tensor = kernel_value # Reparameterize Id tensor and conv w_final = id_tensor + self.pe.weight b_final = self.pe.bias # Introduce reparam conv self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) self.reparam_conv.weight.data = w_final self.reparam_conv.bias.data = b_final for para in self.parameters(): para.detach_() self.__delattr__("pe") class RepMixerBlock(nn.Module): """Implementation of Metaformer block with RepMixer as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, ): """Build RepMixer Block. Args: dim: Number of embedding dimensions. kernel_size: Kernel size for repmixer. Default: 3 mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super().__init__() self.token_mixer = RepMixer( dim, kernel_size=kernel_size, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop Path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = self.token_mixer(x) x = x + self.drop_path(self.layer_scale * self.convffn(x)) else: x = self.token_mixer(x) x = x + self.drop_path(self.convffn(x)) return x class AttentionBlock(nn.Module): """Implementation of metaformer block with MHSA as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, ): """Build Attention Block. Args: dim: Number of embedding dimensions. mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 """ super().__init__() self.norm = norm_layer(dim) self.token_mixer = MHSA(dim=dim) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) else: x = x + self.drop_path(self.token_mixer(self.norm(x))) x = x + self.drop_path(self.convffn(x)) return x def basic_blocks( dim: int, block_index: int, num_blocks: List[int], token_mixer_type: str, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode=False, ) -> nn.Sequential: """Build FastViT blocks within a stage. Args: dim: Number of embedding dimensions. block_index: block index. num_blocks: List containing number of blocks per stage. token_mixer_type: Token mixer type. kernel_size: Kernel size for repmixer. mlp_ratio: MLP expansion ratio. act_layer: Activation layer. norm_layer: Normalization layer. drop_rate: Dropout rate. drop_path_rate: Drop path rate. use_layer_scale: Flag to turn on layer scale regularization. layer_scale_init_value: Layer scale value at initialization. inference_mode: Flag to instantiate block in inference mode. Returns: nn.Sequential object of all the blocks within the stage. """ blocks = [] for block_idx in range(num_blocks[block_index]): block_dpr = ( drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1) ) if token_mixer_type == "repmixer": blocks.append( RepMixerBlock( dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio, act_layer=act_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) ) elif token_mixer_type == "attention": blocks.append( AttentionBlock( dim, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, ) ) else: raise ValueError( "Token mixer type: {} not supported".format(token_mixer_type) ) blocks = nn.Sequential(*blocks) return blocks class FastViT(nn.Module): """ This class implements `FastViT architecture `_ """ def __init__( self, layers, token_mixers: Tuple[str, ...], embed_dims=None, mlp_ratios=None, downsamples=None, se_downsamples=None, repmixer_kernel_size=3, norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.GELU, num_classes=1000, pos_embs=None, down_patch_size=7, down_stride=2, drop_rate=0.0, drop_path_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, init_cfg=None, pretrained=None, cls_ratio=2.0, inference_mode=False, **kwargs, ) -> None: super().__init__() self.num_classes = num_classes if pos_embs is None: pos_embs = [None] * len(layers) if se_downsamples is None: se_downsamples = [False] * len(layers) # Convolutional stem self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode) # Build the main stages of the network architecture network = [] for i in range(len(layers)): # Add position embeddings if requested if pos_embs[i] is not None: network.append( pos_embs[i]( embed_dims[i], embed_dims[i], inference_mode=inference_mode ) ) stage = basic_blocks( embed_dims[i], i, layers, token_mixer_type=token_mixers[i], kernel_size=repmixer_kernel_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) network.append(stage) if i >= len(layers) - 1: break # Patch merging/downsampling between stages. if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, in_channels=embed_dims[i], embed_dim=embed_dims[i + 1], inference_mode=inference_mode, use_se=se_downsamples[i + 1], ) ) self.network = nn.ModuleList(network) # Classifier head self.conv_exp = MobileOneBlock( in_channels=embed_dims[-1], out_channels=int(embed_dims[-1] * cls_ratio), kernel_size=3, stride=1, padding=1, groups=embed_dims[-1], inference_mode=inference_mode, use_se=True, num_conv_branches=1, ) self.head = ( nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self.cls_init_weights) self.init_cfg = copy.deepcopy(init_cfg) def cls_init_weights(self, m: nn.Module) -> None: """Init. for classification""" if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) return x def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: for idx, block in enumerate(self.network): x = block(x) # output only the features of last layer for image classification return x def forward(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.forward_embeddings(x) # through backbone x = self.forward_tokens(x) # for image classification x = self.conv_exp(x) cls_out = self.head(x) return cls_out @register_model def mci0(pretrained=False, **kwargs): """Instantiate MCi0 model variant.""" layers = [2, 6, 10, 2] embed_dims = [64, 128, 256, 512] mlp_ratios = [3, 3, 3, 3] downsamples = [True, True, True, True] se_downsamples = [False, False, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, se_downsamples=se_downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_s"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model def mci1(pretrained=False, **kwargs): """Instantiate MCi1 model variant.""" layers = [4, 12, 20, 4] embed_dims = [64, 128, 256, 512] mlp_ratios = [3, 3, 3, 3] downsamples = [True, True, True, True] se_downsamples = [False, False, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, se_downsamples=se_downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_s"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model def mci2(pretrained=False, **kwargs): """Instantiate MCi2 model variant.""" layers = [4, 12, 24, 4] embed_dims = [80, 160, 320, 640] mlp_ratios = [3, 3, 3, 3] downsamples = [True, True, True, True] se_downsamples = [False, False, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, se_downsamples=se_downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_m"] if pretrained: raise ValueError("Functionality not implemented.") return model