File size: 16,670 Bytes
495fe55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 |
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
"""
Implementation of the following modules is borrowed from ml-cvnets repo:
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py
Please see ACKNOWLEDGEMENTS for license details.
"""
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor, nn
from timm.models import register_model
from mobileclip.modules.common.transformer import (
PositionalEmbedding,
TransformerEncoder,
get_normalization_layer,
)
from mobileclip.modules.image.image_projection import SimpleImageProjectionHead
from mobileclip import logger
class ConvNormAct(nn.Module):
"""
Applies an N-dimensional convolution over an input.
Args:
cfg: Model configuration.
in_channels: :math:`C_{out}` from an expected output of size
:math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
out_channels: :math:`C_{out}` from an expected output of size
:math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
dilation: Dilation rate for convolution. An integer, or tuple of length ``N``.
Default: ``1``.
padding: Padding for convolution. An integer, or tuple of length ``N``.
If not specified, padding is automatically computed based on kernel size and
dilation range. Default : ``None`` (equivalent to ``[
int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
groups: Number of groups in convolution. Default: ``1``.
bias: Use bias. Default: ``False``.
padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular').
Default: ``zeros``.
use_norm: Use normalization layer after convolution. Default: ``True``.
use_act: Use activation layer after convolution (or convolution and normalization).
Default: ``True``.
norm_layer: If not None, the provided normalization layer object will be used.
Otherwise, a normalization object will be created based on config
``model.normalization.*`` opts.
act_layer: If not None, the provided activation function will be used.
Otherwise, an activation function will be created based on config
``model.activation.*`` opts.
Shape:
- Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
- Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
.. note::
For depth-wise convolution, `groups=C_{in}=C_{out}`.
"""
def __init__(
self,
cfg: Dict,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
dilation: Union[int, Tuple[int, ...]] = 1,
padding: Optional[Union[int, Tuple[int, ...]]] = None,
groups: int = 1,
bias: bool = False,
padding_mode: str = "zeros",
use_norm: bool = True,
use_act: bool = True,
norm_layer: Optional[nn.Module] = None,
act_layer: Optional[nn.Module] = None,
*args,
**kwargs,
) -> None:
super().__init__()
self.ndim = 2
if norm_layer is None and use_norm:
norm_type = cfg.get("normalization", "batch_norm")
if norm_type == "batch_norm":
norm_layer = nn.BatchNorm2d(
num_features=out_channels,
momentum=cfg.get("momentum", 0.1),
)
else:
norm_layer = get_normalization_layer(
num_features=out_channels, norm_type=norm_type
)
elif norm_layer is not None and use_norm:
logger.error(
f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided."
)
if act_layer is None and use_act:
act_layer = nn.GELU() # Default to GELU
elif act_layer is not None and use_act:
logger.error(
f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided."
)
if (
use_norm
and any(param[0] == "bias" for param in norm_layer.named_parameters())
and bias
):
assert (
not bias
), "Do not use bias when using normalization layers with bias."
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * self.ndim
if isinstance(stride, int):
stride = (stride,) * self.ndim
if isinstance(dilation, int):
dilation = (dilation,) * self.ndim
assert isinstance(kernel_size, Tuple)
assert isinstance(stride, Tuple)
assert isinstance(dilation, Tuple)
if padding is None:
padding = (
int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim)
)
if in_channels % groups != 0:
logger.error(
"Input channels are not divisible by groups. {}%{} != 0 ".format(
in_channels, groups
)
)
if out_channels % groups != 0:
logger.error(
"Output channels are not divisible by groups. {}%{} != 0 ".format(
out_channels, groups
)
)
block = nn.Sequential()
conv_layer = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size, # type: ignore
stride=stride, # type: ignore
padding=padding,
dilation=dilation, # type: ignore
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
block.add_module(name="conv", module=conv_layer)
self.norm_name = None
if use_norm:
block.add_module(name="norm", module=norm_layer)
self.norm_name = norm_layer.__class__.__name__
self.act_name = None
if use_act:
block.add_module(name="act", module=act_layer)
self.act_name = act_layer.__class__.__name__
self.block = block
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.groups = groups
self.kernel_size = conv_layer.kernel_size
self.bias = bias
self.dilation = dilation
def forward(self, x: Tensor) -> Tensor:
return self.block(x)
class VisionTransformer(nn.Module):
"""
This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model implementation
is inspired from `Early Convolutions Help Transformers See Better <https://arxiv.org/abs/2106.14881>`_
.. note::
Our implementation is different from the original implementation in two ways:
1. Kernel size is odd.
2. Our positional encoding implementation allows us to use ViT with any multiple input scales
3. We do not use StochasticDepth
4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
"""
def __init__(self, cfg, *args, **kwargs) -> None:
super().__init__()
image_channels = 3
num_classes = cfg.get("n_classes", 1000)
self.projection_dim = None
if "projection_dim" in kwargs:
self.projection_dim = kwargs.get("projection_dim")
kernel_sizes_conv_stem = [4, 2, 2]
strides_conv_stem = [4, 2, 2]
# Typically, in the ImageNet dataset, we use 224x224 as a resolution.
# For out ViT implementation, patch size is 16 (16 = 4 * 2 * 2)
# Therefore, total number of embeddings along width and height are (224 / 16)^2
num_embeddings = (224 // 16) ** 2
embed_dim = cfg["embed_dim"]
ffn_dim = cfg["embed_dim"] * 4
pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
n_transformer_layers = cfg["n_transformer_layers"]
num_heads = cfg["n_attn_heads"]
attn_dropout = cfg.get("attn_dropout", 0.0)
dropout = cfg.get("dropout", 0.0)
ffn_dropout = cfg.get("ffn_dropout", 0.0)
norm_layer = cfg.get("norm_layer", "layer_norm")
conv_stem_proj_dim = max(32, embed_dim // 4)
patch_emb = [
ConvNormAct(
cfg=cfg,
in_channels=image_channels,
out_channels=conv_stem_proj_dim,
kernel_size=kernel_sizes_conv_stem[0],
stride=strides_conv_stem[0],
bias=False,
use_norm=True,
use_act=True,
),
ConvNormAct(
cfg=cfg,
in_channels=conv_stem_proj_dim,
out_channels=conv_stem_proj_dim,
kernel_size=kernel_sizes_conv_stem[1],
stride=strides_conv_stem[1],
bias=False,
use_norm=True,
use_act=True,
),
ConvNormAct(
cfg=cfg,
in_channels=conv_stem_proj_dim,
out_channels=embed_dim,
kernel_size=kernel_sizes_conv_stem[2],
stride=strides_conv_stem[2],
bias=True,
use_norm=False,
use_act=False,
),
]
self.patch_emb = nn.Sequential(*patch_emb)
use_cls_token = not cfg.get("no_cls_token", False)
stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
per_layer_stochastic_drop_rate = [
round(x, 3)
for x in np.linspace(0, stochastic_dropout, n_transformer_layers)
]
transformer_blocks = [
TransformerEncoder(
embed_dim=embed_dim,
ffn_latent_dim=ffn_dim,
num_heads=num_heads,
attn_dropout=attn_dropout,
dropout=dropout,
ffn_dropout=ffn_dropout,
transformer_norm_layer=norm_layer,
stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
)
for layer_idx in range(n_transformer_layers)
]
self.post_transformer_norm = get_normalization_layer(
num_features=embed_dim, norm_type=norm_layer
)
self.transformer = nn.Sequential(*transformer_blocks)
if self.projection_dim is None:
self.classifier = nn.Linear(embed_dim, num_classes)
else:
self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
if use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
else:
self.cls_token = None
self.pos_embed = PositionalEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embed_dim,
padding_idx=None,
interpolation_mode="bilinear",
)
self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)
def extract_patch_embeddings(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
# input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
batch_size = x.shape[0]
# [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
patch_emb = self.patch_emb(x)
n_h, n_w = patch_emb.shape[-2:]
# [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
patch_emb = patch_emb.flatten(2)
# [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
patch_emb = patch_emb.transpose(1, 2).contiguous()
n_patches = patch_emb.shape[1]
# we resize the positional encodings dynamically.
pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)
# add positional encodings
patch_emb = pos_emb + patch_emb
# add classification token
if self.cls_token is not None:
# [1, 1, emb_dim] --> [Batch, 1, emb_dim]
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)
# dropout
patch_emb = self.emb_dropout(patch_emb)
return patch_emb, (n_h, n_w)
def _features_from_transformer(
self, x: Tensor, *args, **kwargs
) -> Tuple[Tensor, Tuple[int, int]]:
# this function extract patch embeddings and then apply transformer module to learn
# inter-patch representations
# [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
# and embed_dim is feature dim
x, (n_h, n_w) = self.extract_patch_embeddings(x)
for layer in self.transformer:
x = layer(x)
x = self.post_transformer_norm(x)
return x, (n_h, n_w)
def extract_features(
self, x: Tensor, *args, **kwargs
) -> Tuple[Tensor, Optional[Tensor]]:
# The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
# and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
return_image_embeddings = kwargs.get("return_image_embeddings", False)
# [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
# here, B is batch size, C is input channels
# H and W are input height and width
# N is the number of pixels (or tokens) after processing input with conv stem and reshaping
# We add +1 for cls token (if applicable)
# embed_dim --> embedding dimension
x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)
if self.cls_token is not None:
# [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
cls_embedding, image_embedding = torch.split(
x, split_size_or_sections=[1, x.shape[1] - 1], dim=1
)
cls_embedding = cls_embedding.squeeze(1)
else:
# [B, N, embed_dim] -> [B, embed_dim]
cls_embedding = torch.mean(x, dim=1)
# [B, N, embed_dim]
image_embedding = x
if return_image_embeddings:
# reshape image embedding to 4-D tensor
# [B, N, C] --> [B, C, N]
image_embedding = image_embedding.transpose(1, 2).contiguous()
image_embedding = image_embedding.reshape(
image_embedding.shape[0], -1, n_h, n_w
)
return cls_embedding, image_embedding
else:
return cls_embedding, None
def forward_classifier(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
# classify based on CLS token
cls_embedding = self.classifier(cls_embedding)
return cls_embedding, image_embedding
def forward(self, x: Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
# In ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
# To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
if kwargs.get("return_image_embeddings", False):
out_dict = dict()
prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
out_dict.update({"logits": prediction})
if image_embedding is not None:
out_dict.update({"image_embeddings": image_embedding})
return out_dict
else:
prediction, _ = self.forward_classifier(x, *args, **kwargs)
return prediction
@register_model
def vit_b16(pretrained=False, **kwargs):
# Vision transformer config
cfg = {
"norm_layer": "layer_norm_fp32",
"act_layer": "gelu",
"embed_dim": 768,
"n_transformer_layers": 12,
"n_attn_heads": 12,
}
model = VisionTransformer(cfg=cfg, **kwargs)
if pretrained:
raise ValueError("Functionality not implemented.")
return model
|