#
# For acknowledgement see accompanying ACKNOWLEDGEMENTS file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
#
from typing import Optional, Tuple

import torch
import torch.nn as nn
from timm.models.layers import SqueezeExcite

__all__ = ["ReparamLargeKernelConv"]


class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines the over-parameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_.

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        use_se: bool = False,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of the small kernel conv branch. If
                ``None``, the small kernel branch is omitted.
            inference_mode: If True, instantiates model in inference mode.
                Default: ``False``
            use_se: If True, applies a Squeeze-and-Excitation block after the
                convolution. Default: ``False``
            activation: Activation module. Default: ``nn.GELU``
        """
        super(ReparamLargeKernelConv, self).__init__()
        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2

        # Check if SE is requested
        if use_se:
            self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
        else:
            self.se = nn.Identity()

        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)

        return self.activation(self.se(out))

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            # Zero-pad the small kernel so it can be summed with the large one.
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf, we re-parameterize the
        multi-branched architecture used at training time to obtain a plain
        CNN-like structure for inference.
""" eq_k, eq_b = self.get_kernel_bias() self.lkb_reparam = nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.lkb_origin.conv.dilation, groups=self.groups, bias=True, ) self.lkb_reparam.weight.data = eq_k self.lkb_reparam.bias.data = eq_b self.__delattr__("lkb_origin") if hasattr(self, "small_conv"): self.__delattr__("small_conv") @staticmethod def _fuse_bn( conv: torch.Tensor, bn: nn.BatchNorm2d ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with conv layer. Args: conv: Convolutional kernel weights. bn: Batchnorm 2d layer. Returns: Tuple of (kernel, bias) after fusing batchnorm. """ kernel = conv.weight running_mean = bn.running_mean running_var = bn.running_var gamma = bn.weight beta = bn.bias eps = bn.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential: """Helper method to construct conv-batchnorm layers. Args: kernel_size: Size of the convolution kernel. padding: Zero-padding size. Returns: A nn.Sequential Conv-BN module. """ mod_list = nn.Sequential() mod_list.add_module( "conv", nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=self.stride, padding=padding, groups=self.groups, bias=False, ), ) mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) return mod_list