from typing import Optional, List

import torch
import torch.distributed as dist
import torch.nn as nn
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger

import ext.open_clip as open_clip
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
    """OpenCLIP image backbone.

    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.
    """
    STAGES = 4

    def __init__(
            self,
            img_size: int = 1024,
            model_name: str = '',
            fix: bool = True,
            fix_layers: Optional[List] = None,
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
            # Let rank 0 download and cache the pretrained weights first; the
            # other ranks wait at the barrier and then load from the local cache.
            if rank == 0:
                if init_cfg['type'] == 'clip_pretrain':
                    _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                               return_transform=False, logger=self.logger)
                elif init_cfg['type'] == 'image_pretrain':
                    _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
            else:
                pass
            dist.barrier()

        # Get the clip model
        if init_cfg['type'] == 'clip_pretrain':
            clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                                return_transform=False, logger=self.logger)
        elif init_cfg['type'] == 'image_pretrain':
            clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
        elif init_cfg['type'] == 'Pretrained':
            clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
        else:
            raise NotImplementedError

        self.out_indices = (0, 1, 2, 3)
        model_name_lower = model_name.lower()
        if 'convnext_' in model_name_lower:
            model_type = 'convnext'
            if '_base' in model_name_lower:
                output_channels = [128, 256, 512, 1024]
                feat_size = 0
            elif '_large' in model_name_lower:
                output_channels = [192, 384, 768, 1536]
                feat_size = 0
            elif '_xxlarge' in model_name_lower:
                output_channels = [384, 768, 1536, 3072]
                feat_size = 0
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        elif 'rn' in model_name_lower:
            model_type = 'resnet'
            if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
                output_channels = [256, 512, 1024, 2048]
                feat_size = 7
            elif model_name_lower == 'rn50x4':
                output_channels = [320, 640, 1280, 2560]
                feat_size = 9
            elif model_name_lower == 'rn50x16':
                output_channels = [384, 768, 1536, 3072]
                feat_size = 12
            elif model_name_lower == 'rn50x64':
                output_channels = [512, 1024, 2048, 4096]
                feat_size = 14
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        else:
            raise NotImplementedError(f"{model_name} not supported yet.")

        self.model_name = model_name
        self.fix = fix
        self.model_type = model_type
        self.output_channels = output_channels
        self.feat_size = feat_size

        # Get the visual model
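        # The CLIP visual tower is unpacked into stem, four stages, a pre-head
        # norm and the projection head, so that forward_func() can return
        # multi-scale feature maps while get_clip_feature()/forward_feat() can
        # still compute CLIP image embeddings from the last stage.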
        if self.model_type == 'resnet':
            self.stem = nn.Sequential(*[
                clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
                clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
                clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
            ])
        elif self.model_type == 'convnext':
            self.stem = clip_model.visual.trunk.stem
        else:
            raise ValueError

        if self.model_type == 'resnet':
            self.avgpool = clip_model.visual.avgpool
        elif self.model_type == 'convnext':
            self.avgpool = nn.Identity()
        else:
            raise ValueError

        self.res_layers = []
        for i in range(self.STAGES):
            if self.model_type == 'resnet':
                layer_name = f'layer{i + 1}'
                layer = getattr(clip_model.visual, layer_name)
            elif self.model_type == 'convnext':
                layer_name = f'layer{i + 1}'
                layer = clip_model.visual.trunk.stages[i]
            else:
                raise ValueError
            self.add_module(layer_name, layer)
            self.res_layers.append(layer_name)

        if self.model_type == 'resnet':
            self.norm_pre = nn.Identity()
        elif self.model_type == 'convnext':
            self.norm_pre = clip_model.visual.trunk.norm_pre

        if self.model_type == 'resnet':
            self.head = clip_model.visual.attnpool
        elif self.model_type == 'convnext':
            self.head = nn.Sequential(*[
                clip_model.visual.trunk.head,
                clip_model.visual.head,
            ])

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.fix_layers = fix_layers

        if not self.fix:
            self.train()
            for name, param in self.norm_pre.named_parameters():
                param.requires_grad = False
            for name, param in self.head.named_parameters():
                param.requires_grad = False
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        for name, param in res_layer.named_parameters():
                            param.requires_grad = False

        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        res_layer.train(mode=False)
        return self

    def forward_func(self, x):
        x = self.stem(x)
        x = self.avgpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x).contiguous()
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def get_clip_feature(self, backbone_feat):
        if self.model_type == 'resnet':
            return backbone_feat
        elif self.model_type == 'convnext':
            return self.norm_pre(backbone_feat)
        raise NotImplementedError

    def forward_feat(self, features):
        if self.model_type == 'convnext':
            batch, num_query, channel = features.shape
            features = features.reshape(batch * num_query, channel, 1, 1)
            features = self.head(features)
            return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            num_query, channel, _, _ = features.shape
            features = self.head(features)
            return features

    def forward(self, x):
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs

    def get_text_model(self):
        return OpenCLIPBackboneText(
            self.model_name,
            init_cfg=self.init_cfg
        )
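

# OpenCLIPBackboneText below is the matching frozen text tower; it is normally
# obtained via OpenCLIPBackbone.get_text_model(), which forwards the same
# model_name and init_cfg so both towers come from the same CLIP checkpoint.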
@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
    def __init__(
            self,
            model_name: str = '',
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
            if rank == 0:
                _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                           return_transform=False, logger=self.logger)
            else:
                pass
            dist.barrier()

        # Get the clip model
        clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                            return_transform=False, logger=self.logger)

        # Get the textual model
        self.text_tokenizer = open_clip.get_tokenizer(model_name)
        self.text_transformer = clip_model.transformer
        self.text_token_embedding = clip_model.token_embedding
        self.text_pe = clip_model.positional_embedding
        self.text_ln_final = clip_model.ln_final
        self.text_proj = clip_model.text_projection
        self.register_buffer('text_attn_mask', clip_model.attn_mask)

        self.param_dtype = torch.float32
        self.model_name = model_name

    def init_weights(self):
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    # Copied from
    # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
    @torch.no_grad()
    def forward(self, text):
        text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
        x = self.text_token_embedding(text_tokens).to(self.param_dtype)
        x = x + self.text_pe.to(self.param_dtype)
        x = x.permute(1, 0, 2)
        x = self.text_transformer(x, attn_mask=self.text_attn_mask)
        x = x.permute(1, 0, 2)
        x = self.text_ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
        return x
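

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. The model
    # name and pretrained tag below are illustrative assumptions; substitute
    # identifiers that the vendored ext.open_clip build actually provides.
    backbone = OpenCLIPBackbone(
        model_name='convnext_large_d_320',
        fix=True,
        init_cfg=dict(type='clip_pretrain', checkpoint='laion2b_s29b_b131k_ft_soup'),
    )
    feats = backbone(torch.randn(1, 3, 1024, 1024))
    for idx, feat in enumerate(feats):
        print(f'stage {idx}: {tuple(feat.shape)}')

    text_model = backbone.get_text_model()
    text_emb = text_model(['a photo of a cat', 'a photo of a dog'])
    print(f'text embeddings: {tuple(text_emb.shape)}')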