#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
""" Model schema in open_clip format for inference only. """
import math
from typing import Any, Optional, Dict

import torch
import torch.nn.functional as F
from torch import nn

from mobileclip.text_encoder import (
    TextTransformer,
)

from .image_encoder import MCi


class CLIP(nn.Module):
    """Base class for multi-modal image-text data"""

    def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None:
        super().__init__()
        self.output_dict = output_dict
        self.projection_dim = cfg["embed_dim"]
        if self.projection_dim is None:
            raise ValueError("Please specify `embed_dim` in model config.")

        self.image_encoder = MCi(
            model_name=cfg["image_cfg"]["model_name"],
            projection_dim=self.projection_dim,
        )
        self.text_encoder = TextTransformer(
            cfg=cfg["text_cfg"], projection_dim=self.projection_dim
        )
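        # Learnable temperature for the contrastive logits. Initializing the
        # log-scale to log(1 / 0.07) follows the original CLIP recipe, so the
        # effective scale starts at 1 / 0.07 ~= 14.29.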
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))

    def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
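        # Exponentiate the learned log-scale and clamp the result to
        # `max_scale` so the similarity logits stay bounded at inference.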
        scale = self.logit_scale.exp()
        scale = torch.clamp(scale, 0, max_scale)
        return scale

    def encode_image(self, image: torch.Tensor, normalize: bool = False):
        image_encoder_out = self.image_encoder(image)
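        # Some image encoder variants return a dict of outputs; in that case
        # the projected features live under the "logits" key.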
        if isinstance(image_encoder_out, dict):
            features = image_encoder_out["logits"]
        else:
            features = image_encoder_out
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text: torch.Tensor, normalize: bool = False):
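        # No key padding mask is supplied; the tokenized text is consumed
        # as-is, which matches this module's inference-only usage.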
        text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
        return F.normalize(text_features, dim=-1) if normalize else text_features

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> Any:
        image_embeddings = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_embeddings = (
            self.encode_text(text, normalize=True) if text is not None else None
        )
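
        # Both embeddings are unit-normalized above, so callers can form the
        # similarity matrix as logit_scale * image_features @ text_features.T.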
        if self.output_dict:
            return {
                "image_features": image_embeddings,
                "text_features": text_embeddings,
                "logit_scale": self._exponentiate_and_clip_logits(),
            }
        return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
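

# Usage sketch (illustrative, not part of the original module). The config
# keys mirror what `__init__` reads, but the concrete values are assumptions;
# consult the repo's model configs for real settings.
#
#     cfg = {
#         "embed_dim": 512,                    # projection dimension (assumed)
#         "image_cfg": {"model_name": "..."},  # a valid MCi model name
#         "text_cfg": {...},                   # TextTransformer config dict
#     }
#     model = CLIP(cfg, output_dict=True).eval()
#     with torch.no_grad():
#         out = model(image=images, text=text_tokens)
#     logits = out["logit_scale"] * out["image_features"] @ out["text_features"].T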