from abc import ABC

import torch
import torch.nn as nn
from transformers import CLIPPreTrainedModel, CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPVisionTransformer

from inference.model.language_model.configuration_llava_phi import LlavaPhiVisionConfig


class CLIPVisionTower(CLIPPreTrainedModel):
    """CLIP vision tower for LLaVA-Phi: wraps a CLIPVisionTransformer and returns
    the hidden-state features selected by mm_vision_select_layer / mm_vision_select_feature."""

    config_class = LlavaPhiVisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def feature_select(self, image_forward_outs):
        # Take the hidden state of the layer chosen by mm_vision_select_layer
        # (e.g. -2 for the penultimate transformer layer).
        image_features = image_forward_outs.hidden_states[
            self.config.mm_vision_select_layer
        ]
        if self.config.mm_vision_select_feature == "patch":
            # Drop the CLS token, keep only the patch tokens.
            image_features = image_features[:, 1:]
        elif self.config.mm_vision_select_feature == "cls_patch":
            # Keep both the CLS token and the patch tokens.
            image_features = image_features
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.mm_vision_select_feature}"
            )
        return image_features
    def forward(self, images):
        if isinstance(images, list):
            # List of individually sized images: encode each one separately.
            image_features = []
            for image in images:
                image_forward_out = self.vision_model(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            # Batched tensor of shape (batch, 3, image_size, image_size).
            image_forward_outs = self.vision_model(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_model.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_model.parameters()).device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


if __name__ == "__main__":
    clip_config = CLIPVisionConfig.from_pretrained(
        "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
    )
    print("################ clip_config ##############")
    print(clip_config)
    phi_vis_config = LlavaPhiVisionConfig(**clip_config.to_dict())
    print("################ phi_vis_config ##############")
    print(phi_vis_config)

    model = CLIPVisionTower(clip_config)
    # print(list(model.vision_model.parameters())[0].dtype)
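
    # --- Illustrative sketch, not part of the original script ---
    # This assumes LlavaPhiVisionConfig supplies defaults for
    # mm_vision_select_layer and mm_vision_select_feature; if it does not,
    # set them on the config before running. It builds a randomly
    # initialized tower from phi_vis_config and probes it with a dummy
    # batch to inspect the feature shape handed to the multimodal projector.
    tower = CLIPVisionTower(phi_vis_config)
    tower.eval()
    with torch.no_grad():
        dummy_images = torch.randn(
            1, 3, phi_vis_config.image_size, phi_vis_config.image_size
        )
        features = tower(dummy_images)
    # With "patch" feature selection, expect (1, num_patches, hidden_size).
    print(features.shape, tower.num_patches, tower.hidden_size)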