from typing import Literal, Tuple

import torch
from mmengine import MMLogger
from mmengine.model import BaseModule
from mmengine.structures import InstanceData

from mmdet.registry import MODELS

from ext.sam import PromptEncoder
from ext.meta.sam_meta import meta_dict, checkpoint_dict
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class SAMPromptEncoder(BaseModule):
    """SAM prompt encoder loaded from pretrained weights.

    Thin wrapper around SAM's ``PromptEncoder`` that loads the pretrained
    sub-module from a full SAM checkpoint, optionally freezes it, and embeds
    point, box, and mask prompts carried in an ``InstanceData``.
    """

    def __init__(
        self,
        model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
        fix: bool = True,
        init_cfg=None,
    ):
        assert init_cfg is not None, "init_cfg of type 'sam_pretrain' is required."
        assert init_cfg['type'] == 'sam_pretrain', f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]
        checkpoint_path = checkpoint_dict[pretrained]

        prompt_encoder = PromptEncoder(
            embed_dim=256,
            image_embedding_size=(backbone_meta['image_embedding_size'],
                                  backbone_meta['image_embedding_size']),
            input_image_size=(backbone_meta['image_size'], backbone_meta['image_size']),
            mask_in_chans=16,
        )
        state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='prompt_encoder')
        prompt_encoder.load_state_dict(state_dict, strict=True)

        # meta
        self.embed_dim = prompt_encoder.embed_dim
        self.input_image_size = prompt_encoder.input_image_size
        self.image_embedding_size = prompt_encoder.image_embedding_size
        self.num_point_embeddings = 4
        self.mask_input_size = prompt_encoder.mask_input_size

        # positional encoding
        self.pe_layer = prompt_encoder.pe_layer

        # mask encoding
        self.mask_downscaling = prompt_encoder.mask_downscaling
        self.no_mask_embed = prompt_encoder.no_mask_embed

        # point encoding
        self.point_embeddings = prompt_encoder.point_embeddings
        self.not_a_point_embed = prompt_encoder.not_a_point_embed

        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    @property
    def device(self):
        return self.no_mask_embed.weight.device

    def init_weights(self):
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def train(self, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A fixed (frozen) encoder stays in eval mode regardless of `mode`.
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def _embed_boxes(self, bboxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Embeds box prompts as a pair of corner-point embeddings."""
        bboxes = bboxes + 0.5  # Shift to center of pixel
        coords = bboxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def get_dense_pe(self) -> torch.Tensor:
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
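        # SAM label convention: 1 = positive point, 0 = negative point,
        # -1 = padding ("not a point"). Padding entries get their positional
        # encoding zeroed out and take the learned not-a-point embedding.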
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def forward(
        self,
        instances: InstanceData,
        image_size: Tuple[int, int],
        with_points: bool = False,
        with_bboxes: bool = False,
        with_masks: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert with_points or with_bboxes or with_masks
        bs = len(instances)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.device)
        if with_points:
            assert 'point_coords' in instances
            coords = instances.point_coords
            # All provided points are treated as positive (label 1) clicks.
            labels = torch.ones_like(coords)[:, :, 0]
            # Pad with a not-a-point entry only when no box prompt follows.
            point_embeddings = self._embed_points(coords, labels, pad=not with_bboxes)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if with_bboxes:
            assert 'bboxes' in instances
            box_embeddings = self._embed_boxes(instances.bboxes, image_size=image_size)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
        if with_masks:
            assert 'masks' in instances
            dense_embeddings = self._embed_masks(instances.masks.masks)
        else:
            # No mask prompt: broadcast the learned no-mask embedding over
            # the image-embedding grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
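

# Usage sketch (illustrative, not part of the module). Assumes the
# `checkpoint_dict` registry in `ext.meta.sam_meta` maps 'vit_h' to a local
# SAM checkpoint; shapes below follow SAM's 1024x1024 input / 64x64
# image-embedding convention.
if __name__ == '__main__':
    encoder = SAMPromptEncoder(
        model_name='vit_h',
        fix=True,
        init_cfg=dict(type='sam_pretrain', checkpoint='vit_h'),
    )
    # A single box prompt in (x1, y1, x2, y2) pixel coordinates.
    instances = InstanceData(bboxes=torch.tensor([[100.0, 150.0, 400.0, 500.0]]))
    sparse, dense = encoder(instances, image_size=(1024, 1024), with_bboxes=True)
    print(sparse.shape)  # (1, 2, 256): one embedding per box corner
    print(dense.shape)   # (1, 256, 64, 64): the broadcast no-mask embedding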