ovsam / app /models /model.py
HarborYuan's picture
Add model
9cc3eb2
raw
history blame
No virus
2.99 kB
import torch.nn.functional as F
from mmengine.model import BaseModel
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
@MODELS.register_module()
class SAMSegmentor(BaseModel):
MASK_THRESHOLD = 0.5
def __init__(
self,
backbone: ConfigType,
neck: ConfigType,
prompt_encoder: ConfigType,
mask_decoder: ConfigType,
data_preprocessor: OptConfigType = None,
fpn_neck: OptConfigType = None,
init_cfg: OptMultiConfig = None,
use_clip_feat: bool = False,
use_head_feat: bool = False,
use_gt_prompt: bool = False,
use_point: bool = False,
enable_backbone: bool = False,
) -> None:
super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.backbone = MODELS.build(backbone)
self.neck = MODELS.build(neck)
self.pe = MODELS.build(prompt_encoder)
self.mask_decoder = MODELS.build(mask_decoder)
if fpn_neck is not None:
self.fpn_neck = MODELS.build(fpn_neck)
else:
self.fpn_neck = None
self.use_clip_feat = use_clip_feat
self.use_head_feat = use_head_feat
self.use_gt_prompt = use_gt_prompt
self.use_point = use_point
self.enable_backbone = enable_backbone
def extract_feat(self, inputs):
backbone_feat = self.backbone(inputs)
neck_feat = self.neck(backbone_feat)
if self.fpn_neck is not None:
fpn_feat = self.fpn_neck(backbone_feat)
else:
fpn_feat = None
return dict(
backbone_feat=backbone_feat,
neck_feat=neck_feat,
fpn_feat=fpn_feat
)
def extract_masks(self, feat_cache, prompts):
sparse_embed, dense_embed = self.pe(
prompts,
image_size=(1024, 1024),
with_points='point_coords' in prompts,
with_bboxes='bboxes' in prompts,
)
kwargs = dict()
if self.enable_backbone:
kwargs['backbone_feats'] = feat_cache['backbone_feat']
kwargs['backbone'] = self.backbone
kwargs['fpn_feats'] = feat_cache['fpn_feat']
low_res_masks, iou_predictions, cls_pred = self.mask_decoder(
image_embeddings=feat_cache['neck_feat'],
image_pe=self.pe.get_dense_pe(),
sparse_prompt_embeddings=sparse_embed,
dense_prompt_embeddings=dense_embed,
multi_mask_output=False,
**kwargs
)
masks = F.interpolate(
low_res_masks,
scale_factor=4.,
mode='bilinear',
align_corners=False,
)
masks = masks.sigmoid()
cls_pred = cls_pred.softmax(-1)[..., :-1]
return masks.detach().cpu().numpy(), cls_pred.detach().cpu()
def forward(self, inputs):
return inputs