|
|
|
|
|
|
|
|
|
from typing import Any |
|
|
|
import torch.nn as nn |
|
from timm.models import create_model |
|
|
|
from mobileclip import models |
|
from mobileclip.modules.image.image_projection import GlobalPool2D |
|
|
|
|
|
class MCi(nn.Module):
    """
    This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_

    Wraps a model built with ``timm.models.create_model``. When a
    ``projection_dim`` keyword is supplied (and is not ``None``), the wrapped
    model's ``head`` attribute, if present, is replaced by a
    :class:`GlobalPool2D` that projects to ``projection_dim``.
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        """Create the wrapped model and optionally swap in a projection head.

        Args:
            model_name: Name forwarded to ``timm.models.create_model``.
            *args: Accepted for interface compatibility; not used.
            **kwargs: Only ``projection_dim`` is read. It is forwarded to
                ``create_model`` (as ``None`` when absent) and, when not
                ``None``, used as the output dimension of the replacement head.
        """
        super().__init__()

        # A missing key and an explicit ``None`` are treated the same way.
        self.projection_dim = kwargs.get("projection_dim")

        self.model = create_model(model_name, projection_dim=self.projection_dim)

        if self.projection_dim is not None and hasattr(self.model, "head"):
            # NOTE(review): a model without a ``head`` attribute keeps its
            # original output even though a projection_dim was requested.
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Run the wrapped model on ``x``; any extra arguments are ignored."""
        return self.model(x)

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input feature dimension to the image classification head.

        For an ``nn.Sequential``, this is the ``in_features`` of its first
        direct ``nn.Linear`` child. For a bare ``nn.Linear``, it is the layer's
        own ``in_features``.

        Raises:
            NotImplementedError: If no ``nn.Linear`` is found as described above
                (including any other module type).
        """
        if isinstance(image_classifier, nn.Sequential):
            # Only direct children are inspected; nested containers are not searched.
            in_features = next(
                (
                    layer.in_features
                    for layer in image_classifier
                    if isinstance(layer, nn.Linear)
                ),
                None,
            )
        elif isinstance(image_classifier, nn.Linear):
            in_features = image_classifier.in_features
        else:
            in_features = None

        if in_features is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return in_features

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        """Build a ``GlobalPool2D`` head that replaces ``image_classifier``.

        The new head's input dimension is taken from the existing classifier
        (see ``_get_in_feature_dimension``), and its output dimension is
        ``projection_dim``. Extra positional and keyword arguments are ignored.

        Raises:
            NotImplementedError: Propagated from ``_get_in_feature_dimension``.
        """
        in_features = MCi._get_in_feature_dimension(image_classifier)
        return GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
|
|