import torch.nn.functional as F
from mmengine.model import BaseModel

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


@MODELS.register_module()
class SAMSegmentor(BaseModel):
    """Prompt-driven segmentor assembled from registry-built components.

    The model is composed of a backbone, a neck, a prompt encoder and a mask
    decoder, plus an optional FPN neck. Feature extraction
    (:meth:`extract_feat`) and mask prediction (:meth:`extract_masks`) are
    exposed as separate steps so that features can be computed once and then
    reused for several prompts.

    Args:
        backbone: Config used to build the backbone via ``MODELS.build``.
        neck: Config used to build the neck applied to the backbone output.
        prompt_encoder: Config used to build the prompt encoder
            (stored as ``self.pe``).
        mask_decoder: Config used to build the mask decoder.
        data_preprocessor: Forwarded to ``BaseModel.__init__``.
        fpn_neck: Optional config for an extra neck that is also applied to
            the backbone output. ``None`` disables it.
        init_cfg: Forwarded to ``BaseModel.__init__``.
        use_clip_feat: Stored on the instance; not read by the methods
            defined in this class.
        use_head_feat: Stored on the instance; not read by the methods
            defined in this class.
        use_gt_prompt: Stored on the instance; not read by the methods
            defined in this class.
        use_point: Stored on the instance; not read by the methods defined
            in this class.
        enable_backbone: When true, :meth:`extract_masks` passes additional
            backbone-related keyword arguments to the mask decoder.
    """

    # Not referenced by the methods below; kept as part of the class API.
    MASK_THRESHOLD = 0.5

    def __init__(
            self,
            backbone: ConfigType,
            neck: ConfigType,
            prompt_encoder: ConfigType,
            mask_decoder: ConfigType,
            data_preprocessor: OptConfigType = None,
            fpn_neck: OptConfigType = None,
            init_cfg: OptMultiConfig = None,
            use_clip_feat: bool = False,
            use_head_feat: bool = False,
            use_gt_prompt: bool = False,
            use_point: bool = False,
            enable_backbone: bool = False,
    ) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        # Sub-modules are built in a fixed order so that module registration
        # order stays stable.
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.pe = MODELS.build(prompt_encoder)
        self.mask_decoder = MODELS.build(mask_decoder)
        self.fpn_neck = (
            MODELS.build(fpn_neck) if fpn_neck is not None else None)

        # Behaviour switches.
        self.use_clip_feat = use_clip_feat
        self.use_head_feat = use_head_feat
        self.use_gt_prompt = use_gt_prompt
        self.use_point = use_point
        self.enable_backbone = enable_backbone

    def extract_feat(self, inputs):
        """Run the backbone and neck(s) and return their outputs.

        Args:
            inputs: Passed directly to the backbone.

        Returns:
            dict: ``backbone_feat`` (backbone output), ``neck_feat`` (neck
            applied to the backbone output) and ``fpn_feat`` (FPN neck
            applied to the backbone output, or ``None`` when no FPN neck is
            configured).
        """
        feats = self.backbone(inputs)
        neck_out = self.neck(feats)
        fpn_out = self.fpn_neck(feats) if self.fpn_neck is not None else None
        return {
            'backbone_feat': feats,
            'neck_feat': neck_out,
            'fpn_feat': fpn_out,
        }

    def extract_masks(self, feat_cache, prompts):
        """Predict masks and class scores for ``prompts`` from cached features.

        Args:
            feat_cache: Mapping with the keys produced by
                :meth:`extract_feat`.
            prompts: Prompt container handed to the prompt encoder. It must
                support ``in`` tests for ``'point_coords'`` and ``'bboxes'``
                (presumably a dict; confirm against the caller).

        Returns:
            tuple: ``(masks, cls_pred)`` where ``masks`` is a NumPy array of
            sigmoid-activated masks upsampled by a factor of 4, and
            ``cls_pred`` is a CPU tensor holding the softmax over the last
            dimension with its final entry dropped.
        """
        has_points = 'point_coords' in prompts
        has_boxes = 'bboxes' in prompts
        # The prompt encoder is always told the image is 1024x1024.
        sparse_embed, dense_embed = self.pe(
            prompts,
            image_size=(1024, 1024),
            with_points=has_points,
            with_bboxes=has_boxes,
        )

        decoder_extras = {}
        if self.enable_backbone:
            # NOTE(review): the file's whitespace was collapsed, so the
            # extent of this branch is inferred; all three extras are treated
            # as conditional on ``enable_backbone``. Confirm against the
            # upstream source.
            decoder_extras.update(
                backbone_feats=feat_cache['backbone_feat'],
                backbone=self.backbone,
                fpn_feats=feat_cache['fpn_feat'],
            )

        low_res_masks, iou_predictions, cls_pred = self.mask_decoder(
            image_embeddings=feat_cache['neck_feat'],
            image_pe=self.pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multi_mask_output=False,
            **decoder_extras,
        )

        # Upsample the decoder's low-resolution output, then squash it to
        # (0, 1).
        mask_probs = F.interpolate(
            low_res_masks,
            scale_factor=4.,
            mode='bilinear',
            align_corners=False,
        ).sigmoid()
        # Drop the last entry of the class distribution after normalising.
        cls_scores = cls_pred.softmax(-1)[..., :-1]

        return (
            mask_probs.detach().cpu().numpy(),
            cls_scores.detach().cpu(),
        )

    def forward(self, inputs):
        """Return ``inputs`` unchanged.

        Inference goes through :meth:`extract_feat` and
        :meth:`extract_masks` rather than through ``forward``.
        """
        return inputs