|
import torch
|
|
import torch.nn as nn
|
|
from transformers import CLIPImageProcessor
|
|
from llava.utils import rank0_print
|
|
|
|
try:
|
|
import open_clip
|
|
import torchvision
|
|
from open_clip.transformer import _expand_token
|
|
except ImportError:
|
|
print("OpenCLIP not installed")
|
|
open_clip = None
|
|
|
|
# Hidden (feature) size reported by OpenCLIPVisionTower.hidden_size, keyed by
# OpenCLIP model name (the tower's ``model_name``, i.e. without the
# "open_clip_hub:" prefix).
HIDDEN_SIZE_DICT = {"ViT-H-14-378-quickgelu": 1280}
|
|
|
|
|
|
class OpenCLIPVisionTower(nn.Module):
    """LLaVA vision tower backed by an OpenCLIP visual encoder.

    Produces per-layer hidden states of the OpenCLIP image encoder (one layer is
    picked via ``mm_vision_select_layer``) and builds a HuggingFace
    ``CLIPImageProcessor`` whose resize size and normalization statistics are
    copied from the OpenCLIP preprocessing pipeline.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """Record configuration and optionally load the model right away.

        Args:
            vision_tower: Model identifier, optionally prefixed with
                ``"open_clip_hub:"`` (the prefix is stripped).
            args: Config object. Reads ``vision_tower_pretrained`` and
                ``mm_vision_select_layer``; optionally ``mm_vision_select_feature``
                (default ``"patch"``), ``unfreeze_mm_vision_tower`` and
                ``mm_tunable_parts``.
            delay_load: When True, loading is skipped unless ``args`` enables
                ``unfreeze_mm_vision_tower`` or lists ``"mm_vision_tower"`` in
                ``mm_tunable_parts``.
        """
        super().__init__()

        self.is_loaded = False

        self.model_name = vision_tower.replace("open_clip_hub:", "")
        self.pretrained = args.vision_tower_pretrained
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()

    def load_model(self, device_map="auto"):
        """Create the OpenCLIP model, keep its visual tower, and build the processor.

        The visual tower is frozen (``requires_grad_(False)``); the text tower
        returned by OpenCLIP is discarded.

        Args:
            device_map: Accepted for interface compatibility; unused. The model
                is always created in fp32 on ``"cuda"``.

        Raises:
            ImportError: If ``open_clip`` could not be imported at module load.
            ValueError: If the OpenCLIP preprocessing pipeline has no ``Resize``
                or no ``Normalize`` transform.
        """
        if open_clip is None:
            raise ImportError("open_clip is required for OpenCLIPVisionTower; install `open_clip_torch`.")

        rank0_print(f"Loading OpenCLIP model: {self.model_name}")
        rank0_print(f"Pretrained: {self.pretrained}")
        vision_tower, _, image_processor = open_clip.create_model_and_transforms(
            model_name=self.model_name,
            pretrained=self.pretrained,
            precision="fp32",
            device="cuda",
        )

        resize_transform = self._find_transform(image_processor, torchvision.transforms.Resize)
        normalize_transform = self._find_transform(image_processor, torchvision.transforms.Normalize)
        self.resize_transform_size = resize_transform.size
        # Works for both conv1-based towers and timm-trunk towers (the latter
        # previously crashed here even though forward_visual supports them).
        self.patch_size = self._patch_embed_conv(vision_tower.visual).kernel_size[0]

        self.image_processor = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            crop_size=resize_transform.size,
            size={"shortest_edge": resize_transform.size},
            image_mean=list(normalize_transform.mean),
            image_std=list(normalize_transform.std),
        )
        rank0_print(f"Loaded image processor: {self.image_processor}")
        self.vision_tower = vision_tower.visual
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    @staticmethod
    def _find_transform(pipeline, transform_type):
        """Return the first transform of ``transform_type`` in ``pipeline.transforms``."""
        match = next((t for t in pipeline.transforms if isinstance(t, transform_type)), None)
        if match is None:
            raise ValueError(f"OpenCLIP preprocessing has no {transform_type.__name__} transform")
        return match

    @staticmethod
    def _patch_embed_conv(visual):
        """Return the patch-embedding convolution of an OpenCLIP visual tower.

        Uses ``visual.conv1`` for native OpenCLIP ViTs and
        ``visual.trunk.patch_embed.proj`` for trunk-based towers.

        Raises:
            NotImplementedError: If the tower has neither layout.
        """
        if hasattr(visual, "conv1"):
            return visual.conv1
        if hasattr(visual, "trunk"):
            return visual.trunk.patch_embed.proj
        raise NotImplementedError

    def feature_select(self, image_forward_outs):
        """Pick the configured layer and token subset from per-layer outputs.

        Args:
            image_forward_outs: Sequence of per-layer hidden states as returned
                by ``forward_visual``; indexed with ``self.select_layer``.

        Returns:
            ``"patch"``: the selected layer without its first token (the CLS
            token on the native OpenCLIP path); ``"cls_patch"``: the layer
            unchanged; ``"conv_flatten"``: ``flatten(2).transpose(1, 2)`` of the
            layer (presumably meant for conv-style feature maps).

        Raises:
            ValueError: For any other ``select_feature`` value.
        """
        image_features = image_forward_outs[self.select_layer]
        if self.select_feature == "patch":
            # Token axis is dim 1 because forward_visual returns batch-first features.
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            pass
        elif self.select_feature == "conv_flatten":
            image_features = image_features.flatten(2).transpose(1, 2)
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward_visual(self, x, output_hidden_states=False):
        """Run the visual encoder and return a list of per-block hidden states.

        Args:
            x: Image batch fed to the tower's patch embedding.
            output_hidden_states: Unused; kept for interface compatibility.

        Returns:
            For trunk towers exposing ``_intermediate_layers``: that method's
            result for the last ``abs(select_layer)`` blocks (presumably
            batch-first; verify against the installed timm). Otherwise: one
            (batch, tokens, width) tensor per residual block.
        """
        visual = self.vision_tower
        if hasattr(visual, "trunk") and hasattr(visual.trunk, "_intermediate_layers"):
            return visual.trunk._intermediate_layers(x, abs(self.select_layer))
        return self._forward_openclip_layers(visual, x)

    @staticmethod
    def _forward_openclip_layers(visual, x):
        """Replicate OpenCLIP's ViT stem and collect every residual block's output."""
        x = visual.conv1(x)  # (B, width, grid, grid)
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # (B, grid**2, width)

        # Prepend the class token (equivalent to open_clip's private
        # ``_expand_token``) and add positional embeddings.
        cls_token = visual.class_embedding.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)
        x = torch.cat([cls_token, x], dim=1)  # (B, grid**2 + 1, width)
        x = x + visual.positional_embedding.to(x.dtype)

        x = visual.patch_dropout(x)
        x = visual.ln_pre(x)

        # open_clip releases differ: some mark the transformer ``batch_first``
        # (blocks take NLD); older ones have no such attribute and expect LND.
        # Feed the layout the blocks expect, but always return batch-first
        # features so feature_select slices tokens, not batch entries.
        batch_first = getattr(visual.transformer, "batch_first", False)
        if not batch_first:
            x = x.permute(1, 0, 2)  # NLD -> LND
        features = []
        for block in visual.transformer.resblocks:
            x = block(x, attn_mask=None)
            features.append(x if batch_first else x.permute(1, 0, 2))
        return features

    def forward(self, images):
        """Encode images and return selected features cast back to the input dtype.

        ``images`` is either a batched tensor (returns one tensor) or a list of
        single-image tensors, each encoded separately with a batch dim of 1
        (returns a list of tensors).
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                hidden = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
                image_features.append(self.feature_select(hidden).to(image.dtype))
            return image_features

        hidden = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
        return self.feature_select(hidden).to(images.dtype)

    @property
    def dummy_feature(self):
        """A (1, hidden_size) zero tensor on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the patch-embedding weights (raises NotImplementedError for unknown towers)."""
        return self._patch_embed_conv(self.vision_tower).weight.dtype

    @property
    def device(self):
        """Device of the patch-embedding weights (raises NotImplementedError for unknown towers)."""
        return self._patch_embed_conv(self.vision_tower).weight.device

    @property
    def config(self):
        """No HuggingFace config is available for OpenCLIP towers."""
        return None

    @property
    def hidden_size(self):
        """Feature width of the selected hidden states.

        Uses ``HIDDEN_SIZE_DICT`` when the model is listed there; otherwise, once
        the model is loaded, the patch-embedding conv's ``out_channels``.

        Raises:
            NotImplementedError: If the model is unlisted and not yet loaded.
        """
        if self.model_name in HIDDEN_SIZE_DICT:
            return HIDDEN_SIZE_DICT[self.model_name]
        if getattr(self, "is_loaded", False):
            return self._patch_embed_conv(self.vision_tower).out_channels
        raise NotImplementedError

    def _square_image_size(self):
        """Resize size as a single int (first entry when it is a sequence)."""
        size = self.resize_transform_size
        return size if isinstance(size, int) else size[0]

    @property
    def num_patches(self):
        """Number of tokens per image returned by ``feature_select``."""
        count = (self._square_image_size() // self.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            count += 1  # CLS token is kept
        return count

    @property
    def image_size(self):
        """Resize size taken from the OpenCLIP preprocessing (int or sequence)."""
        return self.resize_transform_size

    @property
    def num_patches_per_side(self):
        """Patch-grid side length; handles int and sequence resize sizes alike."""
        return self._square_image_size() // self.patch_size
|
|
|