import torch
import torch.nn as nn
from transformers import CLIPImageProcessor

try:
    from imagebind.models import imagebind_model
    from imagebind.models.imagebind_model import ModalityType
    from imagebind.data import load_and_transform_audio_data
except ImportError:
    # ImageBind is an optional dependency; load_model() will fail if it is missing.
    pass


class ImageBindWrapper(nn.Module):
    """Frozen ImageBind encoder exposed through a vision-tower-style interface.

    Images are passed as tensors and audio as a dict with an "audios" key;
    both come back as (batch, tokens, 1024) feature tensors.
    """

    def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = select_layer
        self.select_feature = select_feature

        if not delay_load:
            self.load_model()

    def load_model(self):
        # The CLIP processor is only used for image preprocessing;
        # the encoder itself is ImageBind-Huge.
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)

        # Freeze the encoder so it is never updated during training.
        for p in self.vision_tower.parameters():
            p.requires_grad = False
        self.vision_tower.eval()

        self.is_loaded = True

    def train(self, mode=True):
        # Keep the frozen encoder in eval mode even when the surrounding model trains.
        self.training = mode
        if self.is_loaded:
            self.vision_tower.eval()

    @torch.no_grad()
    def forward(self, x):
        if isinstance(x, dict):
            # Audio path: x["audios"] is a list of audio file paths.
            if x["audios"] is not None:
                inputs = {
                    ModalityType.AUDIO: load_and_transform_audio_data(
                        x["audios"], device=self.device
                    ).half()
                }
                embeddings = self.vision_tower(inputs)
                audio_embedding = embeddings[ModalityType.AUDIO]
                # (batch, 1024) -> (batch, 1, 1024): one audio "token" per sample.
                return audio_embedding.unsqueeze(1)
        else:
            # Vision path: x is a preprocessed image batch.
            inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
            embeddings = self.vision_tower(inputs)
            vision_embedding = embeddings[ModalityType.VISION]
            if vision_embedding.ndim == 2:
                # Pooled embedding: (batch, 1024) -> (batch, 1, 1024).
                return vision_embedding.unsqueeze(1)
            if vision_embedding.shape[1] == 257:
                # Patch tokens: drop the leading CLS token, keep the 256 patch features.
                return vision_embedding[:, 1:]
            raise ValueError(f"Unexpected shape: {vision_embedding.shape}")

    @property
    def dummy_feature(self):
        return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.dtype

    @property
    def device(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.device

    @property
    def hidden_size(self):
        return 1024
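

# --- Usage sketch (not part of the original module) -------------------------
# A minimal illustration of the two forward() input conventions above. It
# assumes the `imagebind` package and its pretrained weights are available
# locally; the constructor arguments beyond `vision_tower` are stored but not
# otherwise used by this wrapper, and "sample.wav" is a hypothetical path.
if __name__ == "__main__":
    tower = ImageBindWrapper(vision_tower="imagebind_huge", select_layer=-2)

    # Vision path: a preprocessed 224x224 image batch. The pooled ImageBind
    # embedding is returned as a single 1024-dim token per sample.
    images = torch.randn(2, 3, 224, 224, dtype=tower.dtype, device=tower.device)
    image_feats = tower(images)
    print(image_feats.shape)  # expected: torch.Size([2, 1, 1024])

    # Audio path: forward() expects a dict like {"audios": ["sample.wav"]} of
    # audio file paths. Note that branch casts the transformed audio to half
    # precision, so the tower itself is expected to be in fp16 there
    # (e.g. after tower.half()).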