from typing import Optional, Tuple

import torch
import torch.nn as nn

from timm.models.layers import SqueezeExcite

__all__ = ["ReparamLargeKernelConv"]


class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines the overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_.

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
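
    Example (illustrative usage; the shapes and hyperparameters below are
    placeholders, not values from the paper)::

        >>> block = ReparamLargeKernelConv(
        ...     in_channels=8, out_channels=8, kernel_size=7,
        ...     stride=1, groups=8, small_kernel=3,
        ... ).eval()
        >>> block.reparameterize()  # fuse branches into ``lkb_reparam``
        >>> out = block(torch.randn(1, 8, 32, 32))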
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        use_se: bool = False,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of the small kernel conv branch. If
                ``None``, only the large kernel branch is built.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            use_se: If True, applies a Squeeze-and-Excitation block to the output. Default: ``False``
            activation: Activation module. Default: ``nn.GELU()``
        """
        super().__init__()

        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2

        if use_se:
            self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
        else:
            self.se = nn.Identity()

        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            # Inference mode: a single fused large-kernel conv.
            out = self.lkb_reparam(x)
        else:
            # Training mode: sum of the large and small kernel branches.
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)
        return self.activation(self.se(out))

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            # Zero-pad the fused small kernel so it sits centered inside the
            # large kernel window before the two kernels are summed.
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again`
        (https://arxiv.org/pdf/2101.03697.pdf), we re-parameterize the
        multi-branched architecture used at training time to obtain a plain
        CNN-like structure for inference.
        """
        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        Args:
            conv: Convolutional layer whose weights the BN statistics are folded into.
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        # Fold BN into the conv: W' = W * gamma / std, b' = beta - mean * gamma / std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            A nn.Sequential Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
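

if __name__ == "__main__":
    # Minimal sketch (not part of the original module): a numerical sanity
    # check that the fused single-branch conv reproduces the multi-branch
    # training-time output. All shapes and hyperparameters below are
    # illustrative assumptions, not values from RepLKNet.
    block = ReparamLargeKernelConv(
        in_channels=8,
        out_channels=8,
        kernel_size=7,
        stride=1,
        groups=8,
        small_kernel=3,
    )
    block.eval()  # BN must use running statistics for the fusion to be exact.

    x = torch.randn(1, 8, 32, 32)
    with torch.no_grad():
        out_multi_branch = block(x)
        block.reparameterize()
        out_fused = block(x)

    # The two outputs should agree up to floating-point tolerance.
    assert torch.allclose(out_multi_branch, out_fused, atol=1e-5)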