irotem98's picture
moondream_model_state_dict.pt
495fe55
raw
history blame
9.42 kB
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Optional
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from mobileclip.modules.common.mobileone import MobileOneBlock
class ConvFFN(nn.Module):
    """Convolutional FFN Module.

    Computes ``drop(fc2(drop(act(fc1(bn(dwconv(x)))))))`` where ``dwconv`` is a
    depthwise ``(1, context_size)`` convolution (no bias) followed by
    BatchNorm, and ``fc1``/``fc2`` are 1x1 convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        context_size: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            context_size: Context size for 1D signals, i.e. the width of the
                depthwise kernel. Padding is ``context_size // 2``, so only an
                odd value keeps the output width equal to the input width.
            hidden_channels: Number of channels after expansion.
                Default: None (falls back to ``in_channels``).
            out_channels: Number of output channels.
                Default: None (falls back to ``in_channels``).
            act_layer: Activation layer class; instantiated with no arguments.
                Default: ``GELU``
            drop: Dropout rate applied after the activation and after ``fc2``.
                Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = nn.Sequential()
        # The depthwise conv only mixes along the context axis, so it must keep
        # ``in_channels`` channels: ``fc1`` below consumes ``in_channels``.
        # Producing ``out_channels`` here would break every configuration with
        # ``out_channels != in_channels``.
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, int(context_size)),
                padding=(0, int(context_size // 2)),
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=in_channels))
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Init every Conv2d: truncated-normal weights (std 0.02), zero bias."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FFN.

        Args:
            x: 4-D input with ``in_channels`` channels at dim 1 (as required
                by ``nn.Conv2d``).

        Returns:
            Tensor with ``out_channels`` channels; the spatial size is
            preserved when ``context_size`` is odd.
        """
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_

    Training form: ``x + layer_scale * (mixer(x) - norm(x))`` (or without the
    scale when ``use_layer_scale`` is False), where ``mixer`` and ``norm`` are
    ``(1, kernel_size)`` MobileOne blocks built with ``groups=dim``.
    Inference form: a single ``(1, kernel_size)`` conv with ``groups=dim`` and
    bias (``reparam_conv``), produced by :meth:`reparameterize` or created
    directly with ``inference_mode=True``.
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode
        self.use_layer_scale = use_layer_scale
        if inference_mode:
            # Already-fused form. No ``layer_scale`` is registered here:
            # ``reparameterize`` folds it into ``reparam_conv`` and deletes it,
            # so keeping the parameter set identical lets a reparameterized
            # state dict load strictly into a model built in inference mode.
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=(1, self.kernel_size),
                stride=1,
                padding=(0, self.kernel_size // 2),
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
            )
            if use_layer_scale:
                # Shape (dim, 1, 1) broadcasts over a (B, dim, H, W) input.
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens along the last (width) axis.

        Args:
            x: Input of shape ``(B, dim, H, W)``.

        Returns:
            Tensor of the same layout as ``x``.
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        Deletes ``mixer``, ``norm`` and (if present) ``layer_scale`` and marks
        the module as being in inference mode, so repeated calls are no-ops.
        """
        if self.inference_mode:
            return
        self.mixer.reparameterize()
        self.norm.reparameterize()
        # ``id_tensor`` is presumably the identity kernel MobileOneBlock builds
        # while fusing its identity/BN branch; adding it folds the residual
        # ``x`` into the fused weights (TODO confirm against MobileOneBlock).
        if self.use_layer_scale:
            # layer_scale (dim, 1, 1) -> (dim, 1, 1, 1) to broadcast over the
            # (dim, 1, 1, kernel_size) depthwise weights.
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=(1, self.kernel_size),
            stride=1,
            padding=(0, self.kernel_size // 2),
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b
        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")
        # Without this, a second call would dereference the deleted ``mixer``.
        self.inference_mode = True
class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_

    Operates on 3-D sequences ``(B, C, D)``: ``x = token_mixer(x)`` followed by
    ``x = x + drop_path(layer_scale * convffn(x))`` (scale omitted when
    ``use_layer_scale`` is False).
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 11,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
        *args,
        **kwargs,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer, also used as the ConvFFN
                context size. Should be odd: ConvFFN pads by
                ``kernel_size // 2``, so an even size lengthens the context
                axis by one and the residual add fails. Default: 11
            mlp_ratio: MLP expansion ratio; must be > 0. Default: 4.0
            act_layer: Activation layer class for the FFN. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate; 0 disables it (``nn.Identity``). Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
            *args, **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )
        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            context_size=kernel_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        # Drop Path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # Layer Scale: shape (dim, 1, 1) broadcasts over the (B, dim, 1, C)
        # tensor that forward() builds.
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x, *args, **kwargs):
        """Run the block on a sequence tensor.

        Args:
            x: Tensor of shape ``(B, C, D)`` where C is the context length and
                D equals ``dim``.
            *args, **kwargs: Ignored.

        Returns:
            Tensor of shape ``(B, C, D)`` (for odd ``kernel_size``).

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}"
            )
        # B, C, D -> B, D, 1, C so the context axis becomes the conv width,
        # matching the (1, kernel_size) kernels used by RepMixer and ConvFFN.
        x = torch.unsqueeze(x.permute(0, 2, 1), dim=2)
        x = self.token_mixer(x)
        ffn_out = self.convffn(x)
        if self.use_layer_scale:
            ffn_out = self.layer_scale * ffn_out
        x = x + self.drop_path(ffn_out)
        # B, D, 1, C -> B, C, D
        return x.squeeze(dim=2).permute(0, 2, 1)