edge_vlm / mobileclip /image_encoder.py
irotem98's picture
moondream_model_state_dict.pt
495fe55
raw
history blame contribute delete
No virus
2.46 kB
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Any
import torch.nn as nn
from timm.models import create_model
from mobileclip import models # Added to register models
from mobileclip.modules.image.image_projection import GlobalPool2D
class MCi(nn.Module):
    """
    This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_

    The backbone is built with ``create_model``. When a ``projection_dim``
    keyword argument is given and the backbone exposes a ``head`` attribute,
    that head is replaced by a ``GlobalPool2D`` module constructed with the
    head's input feature dimension and ``projection_dim``.
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        """Create the backbone and, if requested, swap in a projection head.

        Args:
            model_name: Model name forwarded to ``create_model``.
            *args: Accepted for interface compatibility; not used here.
            **kwargs: Only ``projection_dim`` is read. It is forwarded to
                ``create_model`` and, when not ``None``, used as the output
                dimension of the replacement head.
        """
        super().__init__()
        # ``dict.get`` already yields None when the key is absent.
        self.projection_dim = kwargs.get("projection_dim")

        # Create model
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        # Build out projection head. Backbones without a ``head`` attribute
        # are deliberately left untouched.
        if self.projection_dim is not None and hasattr(self.model, "head"):
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Run ``x`` through the wrapped model; extra arguments are ignored."""
        return self.model(x)

    @staticmethod
    def _first_linear_in_features(module: nn.Module) -> "int | None":
        """Return ``in_features`` of the first ``nn.Linear`` in execution order.

        Only ``nn.Linear`` modules and (possibly nested) ``nn.Sequential``
        containers are inspected, because a ``Sequential`` runs its children
        in order. Any other module type is skipped.

        Returns:
            The ``in_features`` value, or ``None`` when no linear layer is found.
        """
        if isinstance(module, nn.Linear):
            return module.in_features
        if isinstance(module, nn.Sequential):
            for layer in module:
                found = MCi._first_linear_in_features(layer)
                if found is not None:
                    return found
        return None

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input feature dimension to the image classification head.

        Args:
            image_classifier: Either an ``nn.Linear`` or an ``nn.Sequential``
                (which may contain nested ``nn.Sequential`` containers) holding
                at least one ``nn.Linear``.

        Returns:
            ``in_features`` of the first linear layer the head would execute.

        Raises:
            NotImplementedError: If no linear layer can be located.
        """
        # Classifier that uses nn.Sequential usually has global pooling and
        # multiple linear layers. Find the first linear layer (descending into
        # nested Sequential containers) and get its in_features.
        in_features = MCi._first_linear_in_features(image_classifier)
        if in_features is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return in_features

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        """Build the replacement head for ``image_classifier``.

        Args:
            image_classifier: Existing head; only its input feature dimension
                is read.
            projection_dim: Passed as ``out_dim`` to ``GlobalPool2D``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            A new ``GlobalPool2D(in_dim=<head in_features>, out_dim=projection_dim)``.

        Raises:
            NotImplementedError: If the head's input dimension cannot be found.
        """
        in_features = MCi._get_in_feature_dimension(image_classifier)
        return GlobalPool2D(in_dim=in_features, out_dim=projection_dim)