import torch import torch.nn as nn import random class MaskedDrop(nn.Module): def __init__(self, model_args): super().__init__() self.mode = model_args.mm_mask_drop_mode self.skip_percentage = model_args.mm_mask_drop_skip_percentage self.ratio = model_args.mm_mask_drop_ratio self.ratio_upper = model_args.mm_mask_drop_ratio_upper self.ratio_lower = model_args.mm_mask_drop_ratio_lower def forward(self, image_features, *args, **kwargs): if not self.training: return image_features if self.skip_percentage > random.random(): return image_features masked_features = [] for image_feature in image_features: num_tokens = image_feature.shape[0] if self.mode == 'fixed': num_keep = int(num_tokens * self.ratio) masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) elif self.mode == 'range': num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) elif self.mode == 'cls_only': masked_features.append(image_feature[0:1]) else: raise ValueError(f'Unexpected masked drop mode: {self.mode}') if self.mode not in ['range'] and \ (type(image_features) is not list or self.mode in ['cls_only']): masked_features = torch.stack(masked_features, dim=0) return masked_features @property def config(self): return { 'mm_resampler_type': 'masked_drop', 'mm_mask_drop_mode': self.mode, 'mm_mask_drop_skip_percentage': self.skip_percentage, 'mm_mask_drop_ratio': self.ratio, 'mm_mask_drop_ratio_upper': self.ratio_upper, 'mm_mask_drop_ratio_lower': self.ratio_lower, } def random_masking(self, x, len_keep): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim noise = torch.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = torch.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore