#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Optional

import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

from mobileclip.modules.common.mobileone import MobileOneBlock


class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        context_size: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            context_size: Context size for 1D signals.
            hidden_channels: Number of channels after expansion. Default: None
            out_channels: Number of output channels. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = nn.Sequential()
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, int(context_size)),
                padding=(0, int(context_size // 2)),
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
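
# Usage sketch for ConvFFN (illustrative; the shapes below are assumptions, not
# values fixed by this module). Inputs are 4D maps of shape (B, C, 1, L), i.e.
# 1D token sequences laid out along the width axis; with the default
# out_channels, the depthwise conv + pointwise MLP preserve that shape:
#
#     ffn = ConvFFN(in_channels=64, context_size=11, hidden_channels=256)
#     y = ffn(torch.randn(2, 64, 1, 77))  # -> (2, 64, 1, 77)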


class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected
                input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=(1, self.kernel_size),
                stride=1,
                padding=(0, self.kernel_size // 2),
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.
        """
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=(1, self.kernel_size),
            stride=1,
            padding=(0, self.kernel_size // 2),
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")
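
# Reparameterization sketch for RepMixer (illustrative; dim and sequence length
# are assumed values). In train mode the module computes
# x + layer_scale * (mixer(x) - norm(x)) from separate branches; after
# `reparameterize()` the same mapping is served by a single depthwise conv:
#
#     mixer = RepMixer(dim=64, kernel_size=3)
#     mixer.eval()
#     x = torch.randn(1, 64, 1, 77)
#     y_train = mixer(x)
#     mixer.reparameterize()   # fuses branches into `reparam_conv`
#     y_infer = mixer(x)       # single Conv2d call, matches up to numerics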


class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 11,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
        *args,
        **kwargs,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 11
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            context_size=kernel_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop Path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x, *args, **kwargs):
        if x.dim() == 3:
            # B, C, D --- where C is the context length
            # Convert to B, D, C --- to match RepMixer impl.
            x = x.permute(0, 2, 1)
            x = torch.unsqueeze(x, dim=2)
        else:
            raise ValueError(
                f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}"
            )

        if self.use_layer_scale:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.layer_scale * self.convffn(x))
        else:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.convffn(x))

        # Convert tensors back
        x = x.squeeze(dim=2).permute(0, 2, 1)
        return x
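

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the batch size, context length and
    # embedding dim below are assumptions, not values required by the module).
    block = RepMixerBlock(dim=64, kernel_size=11)
    block.eval()
    tokens = torch.randn(2, 77, 64)  # (batch, context_length, embedding_dim)
    with torch.no_grad():
        out = block(tokens)
    # The block preserves the input shape: (2, 77, 64).
    print(out.shape)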