THUdyh's picture
update space
d9c19b7 verified
raw
history blame
No virus
1.19 kB
import torch
from .masked_drop import MaskedDrop
from .spatial_pool import SpatialPool
from .qformer import Qformer
from .vlm_attention import VlmAttention
from .perceiver import DynamicCompressor
class IdentityMap(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_resampler_type": None}
def build_vision_resampler(model_args, delay_load=False, **kwargs):
# import pdb;pdb.set_trace()
resampler_type = getattr(model_args, 'mm_resampler_type', None)
if resampler_type == 'masked_drop':
return MaskedDrop(model_args)
elif resampler_type == 'spatial_pool':
return SpatialPool(model_args, **kwargs)
elif resampler_type == 'qformer':
return Qformer(model_args, **kwargs)
elif resampler_type == 'vlm_attention':
return VlmAttention(model_args,**kwargs)
elif resampler_type == 'dynamic_compressor':
return DynamicCompressor(model_args, **kwargs)
elif resampler_type is None:
return IdentityMap()
else:
raise ValueError(f'Unknown resampler type: {resampler_type}')