|
|
|
|
|
|
|
|
|
from typing import List, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
from mobileclip import logger |
|
|
|
|
|
class GlobalPool(nn.Module):
    """
    This layer applies global pooling over a 4D or 5D input tensor

    Args:
        pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
        keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`
    """

    # Closed set of supported pooling modes; checked in __init__.
    pool_types = ["mean", "rms", "abs"]

    def __init__(
        self,
        pool_type: Optional[str] = "mean",
        keep_dim: Optional[bool] = False,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if pool_type not in self.pool_types:
            # NOTE(review): this only reports the problem via the project
            # logger; if logger.error does not abort, execution continues and
            # _global_pool silently falls back to mean pooling.
            logger.error(
                "Supported pool types are: {}. Got {}".format(
                    self.pool_types, pool_type
                )
            )
        self.pool_type = pool_type
        self.keep_dim = keep_dim

    def _global_pool(self, x: Tensor, dims: List):
        """Reduce `x` over `dims` according to the configured pool type."""
        if self.pool_type == "rms":
            # Root-mean-square: sqrt(mean(x^2)).
            # Fixed: the previous exponent of -0.5 computed the *reciprocal*
            # RMS (1 / rms), which is mathematically wrong for an "rms" pool
            # and yields inf for an all-zero input.
            x = x**2
            x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
            x = x**0.5
        elif self.pool_type == "abs":
            # Mean of absolute values.
            x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
        else:
            # Default: plain mean pooling (also the fallback for any
            # unrecognized pool_type that slipped past __init__'s check).
            x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
        return x

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() == 4:
            dims = [-2, -1]  # (N, C, H, W): pool over H, W
        elif x.dim() == 5:
            dims = [-3, -2, -1]  # (N, C, D, H, W): pool over D, H, W
        else:
            raise NotImplementedError("Currently 2D and 3D global pooling supported")
        return self._global_pool(x, dims=dims)
|
|
|
|
|
class GlobalPool2D(nn.Module):
    """Global mean pooling followed by a learned linear projection.

    A 4D activation map is mean-pooled over its spatial dimensions and the
    resulting per-sample vector is projected from ``in_dim`` to ``out_dim``
    channels via a bias-free parameter matrix initialized from a normal
    distribution scaled by ``1 / sqrt(in_dim)``.
    """

    def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
        super().__init__()
        init_std = in_dim**-0.5
        self.pool = GlobalPool(pool_type="mean", keep_dim=False)
        self.proj = nn.Parameter(init_std * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # Only rank-4 inputs (Batch x in_dim x in_height x in_width) are valid.
        assert (
            x.dim() == 4
        ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
            x.shape
        )

        # (N, C, H, W) -> (N, C) via spatial mean pooling.
        pooled = self.pool(x)
        # (N, in_dim) @ (in_dim, out_dim) -> (N, out_dim).
        return torch.matmul(pooled, self.proj)
|
|
|
|
|
class SimpleImageProjectionHead(nn.Module):
    """A single bias-free linear projection head for image embeddings.

    The projection is stored as a bare ``nn.Parameter`` of shape
    ``(in_dim, out_dim)``, initialized from a normal distribution scaled by
    ``1 / sqrt(in_dim)``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        init_std = in_dim**-0.5
        self.proj = nn.Parameter(init_std * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # Guard: expect a rank-2 batch of embedding vectors.
        assert (
            x.dim() == 2
        ), "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape)

        # (N, in_dim) @ (in_dim, out_dim) -> (N, out_dim).
        return torch.matmul(x, self.proj)
|
|