import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicCompressor(nn.Module):
    """Pool a grid of tokens and refine each pooled token with window attention.

    Each output token is the average-pooled feature of an ``s x s`` window of
    the input grid (``s`` is chosen by ``forward_type``). It is added to a
    single-query attention read-out over the ``s * s`` original tokens inside
    that same window.
    """

    def __init__(self, model_args, vision_tower):
        """Create the query/key projectors.

        Args:
            model_args: Accepted for interface compatibility; not used here.
            vision_tower: Object exposing ``hidden_size``, used as the feature
                width of the incoming tokens.
        """
        super().__init__()

        self.out_channels = vision_tower.hidden_size
        # Width of the projected space in which query/key scores are computed.
        self.mid_channel = 256

        self.vlm_query_projector = nn.Linear(self.out_channels, self.mid_channel)
        self.vlm_key_projector = nn.Linear(self.out_channels, self.mid_channel)

    def downsample(self, x):
        """Average-pool the last two dims of ``x`` with a 2x2 kernel, stride 2."""
        return F.avg_pool2d(x, 2, 2)

    def downsample_4(self, x):
        """Average-pool the last two dims of ``x`` with a 4x4 kernel, stride 4."""
        return F.avg_pool2d(x, 4, 4)

    def forward(self, image_features, forward_type, image_size=None):
        """Compress a batch of token grids.

        Args:
            image_features: Tensor of shape ``(T, N, C)`` with ``N == H * W``.
            forward_type: ``'video'`` (2x2 windows), ``'video_long'`` (4x4
                windows), or ``'image'`` / ``'text'`` (1x1 windows, no pooling).
            image_size: Optional ``(H, W)`` of the token grid. When omitted the
                grid is assumed square with side ``isqrt(N)``.

        Returns:
            A tuple ``(features, (H', W'))``. ``features`` has shape
            ``(T, H' * W', C)``, where ``H' = H // s`` and ``W' = W // s``.

        Raises:
            NotImplementedError: If ``forward_type`` is not one listed above.

        When ``H`` or ``W`` is not a multiple of the window size ``s``, the
        trailing rows/columns are ignored. This matches the floor behavior of
        ``F.avg_pool2d``.
        """
        if image_size is None:
            ori_W = math.isqrt(image_features.shape[1])
            ori_H = ori_W
        else:
            ori_H, ori_W = image_size

        T, N, C = image_features.shape
        # (T, N, C) -> (T, C, H, W) so the 2-D pooling ops apply to the grid.
        image_features = image_features.view(T, ori_H, ori_W, C).permute(0, 3, 1, 2)

        if forward_type == 'video':
            stride, pool = 2, self.downsample
        elif forward_type in ('image', 'text'):
            stride, pool = 1, None
        elif forward_type == 'video_long':
            stride, pool = 4, self.downsample_4
        else:
            raise NotImplementedError(f"Unsupported forward_type: {forward_type!r}")

        new_H, new_W = ori_H // stride, ori_W // stride
        new_image_size = (new_H, new_W)

        image_features_pool = image_features if pool is None else pool(image_features)

        # Crop away trailing rows/cols that floor-mode pooling ignores, so each
        # attention window maps one-to-one onto a pooled token.
        windows = image_features[:, :, :new_H * stride, :new_W * stride]
        # (T, C, H'*s, W'*s) -> (T, H' * W', s * s, C): the tokens of each window.
        image_feature_attn = (
            windows.reshape(T, C, new_H, stride, new_W, stride)
            .permute(0, 2, 4, 3, 5, 1)
            .reshape(T, new_H * new_W, stride * stride, C)
        )

        # (T, C, H', W') -> (T, H' * W', C)
        image_features_pool = image_features_pool.flatten(2).permute(0, 2, 1)
        new_t, new_p, _ = image_features_pool.shape

        image_query = self.vlm_query_projector(image_features_pool).reshape(
            new_t * new_p, self.mid_channel
        )
        image_key = self.vlm_key_projector(image_feature_attn).reshape(
            new_t * new_p, -1, self.mid_channel
        )
        image_value = image_feature_attn.reshape(new_t * new_p, -1, self.out_channels)

        # One query per pooled token scores the s*s keys of its window, scaled
        # by 1/sqrt(mid_channel).
        image_attn = image_query[:, None] @ (
            image_key.transpose(-1, -2) / (image_key.shape[-1] ** 0.5)
        )
        # Sanitise non-finite scores before softmax (NaN -> 0, +/-inf -> finite).
        image_attn = image_attn.nan_to_num()
        # With a single query, mean(1) just drops that length-1 axis.
        attn_feat = (image_attn.softmax(-1) @ image_value).mean(1).reshape(new_t, new_p, C)

        image_features_pool = image_features_pool + attn_feat
        return image_features_pool, new_image_size

    @property
    def config(self):
        """Serializable description of this resampler."""
        return {
            'mm_resampler_type': 'dynamic_compressor',
            'mm_out_channels': self.out_channels,
        }

    @property
    def hidden_size(self):
        """Feature width of the returned tokens (same as the input width)."""
        return self.out_channels