# Copyright (c) OpenMMLab. All rights reserved. from typing import List import torch import torch.nn.functional as F from mmengine.structures import InstanceData, PixelData from torch import Tensor from mmdet.evaluation.functional import INSTANCE_OFFSET from mmdet.registry import MODELS from mmdet.structures import SampleList from mmdet.structures.mask import mask2bbox from mmdet.utils import OptConfigType, OptMultiConfig from mmdet.models.seg_heads.panoptic_fusion_heads.base_panoptic_fusion_head import BasePanopticFusionHead @MODELS.register_module() class OMGFusionHead(BasePanopticFusionHead): def __init__( self, num_things_classes: int = 80, num_stuff_classes: int = 53, test_cfg: OptConfigType = None, loss_panoptic: OptConfigType = None, init_cfg: OptMultiConfig = None, **kwargs ): super().__init__( num_things_classes=num_things_classes, num_stuff_classes=num_stuff_classes, test_cfg=test_cfg, loss_panoptic=loss_panoptic, init_cfg=init_cfg, **kwargs) def loss(self, **kwargs): """MaskFormerFusionHead has no training loss.""" return dict() def panoptic_postprocess(self, mask_cls: Tensor, mask_pred: Tensor) -> PixelData: """Panoptic segmengation inference. Args: mask_cls (Tensor): Classfication outputs of shape (num_queries, cls_out_channels) for a image. Note `cls_out_channels` should includes background. mask_pred (Tensor): Mask outputs of shape (num_queries, h, w) for a image. Returns: :obj:`PixelData`: Panoptic segment result of shape \ (h, w), each element in Tensor means: \ ``segment_id = _cls + instance_id * INSTANCE_OFFSET``. """ object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8) iou_thr = self.test_cfg.get('iou_thr', 0.8) filter_low_score = self.test_cfg.get('filter_low_score', False) scores, labels = F.softmax(mask_cls, dim=-1).max(-1) mask_pred = mask_pred.sigmoid() keep = labels.ne(self.num_classes) & (scores > object_mask_thr) cur_scores = scores[keep] cur_classes = labels[keep] cur_masks = mask_pred[keep] cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks h, w = cur_masks.shape[-2:] panoptic_seg = torch.full((h, w), self.num_classes, dtype=torch.int32, device=cur_masks.device) if cur_masks.shape[0] == 0: # We didn't detect any mask :( pass else: cur_mask_ids = cur_prob_masks.argmax(0) instance_id = 1 for k in range(cur_classes.shape[0]): pred_class = int(cur_classes[k].item()) isthing = pred_class < self.num_things_classes mask = cur_mask_ids == k mask_area = mask.sum().item() original_area = (cur_masks[k] >= 0.5).sum().item() if filter_low_score: mask = mask & (cur_masks[k] >= 0.5) if mask_area > 0 and original_area > 0: if mask_area / original_area < iou_thr: continue if not isthing: # different stuff regions of same class will be # merged here, and stuff share the instance_id 0. panoptic_seg[mask] = pred_class else: panoptic_seg[mask] = ( pred_class + instance_id * INSTANCE_OFFSET) instance_id += 1 return PixelData(sem_seg=panoptic_seg[None]) def semantic_postprocess(self, mask_cls: Tensor, mask_pred: Tensor) -> PixelData: """Semantic segmengation postprocess. Args: mask_cls (Tensor): Classfication outputs of shape (num_queries, cls_out_channels) for a image. Note `cls_out_channels` should includes background. mask_pred (Tensor): Mask outputs of shape (num_queries, h, w) for a image. Returns: :obj:`PixelData`: Semantic segment result. """ # TODO add semantic segmentation result raise NotImplementedError def instance_postprocess(self, mask_cls: Tensor, mask_pred: Tensor) -> InstanceData: """Instance segmengation postprocess. Args: mask_cls (Tensor): Classfication outputs of shape (num_queries, cls_out_channels) for a image. Note `cls_out_channels` should includes background. mask_pred (Tensor): Mask outputs of shape (num_queries, h, w) for a image. Returns: :obj:`InstanceData`: Instance segmentation results. - scores (Tensor): Classification scores, has a shape (num_instance, ) - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). - masks (Tensor): Has a shape (num_instances, H, W). """ max_per_image = self.test_cfg.get('max_per_image', 100) num_queries = mask_cls.shape[0] # shape (num_queries, num_class) scores = F.softmax(mask_cls, dim=-1)[:, :-1] # shape (num_queries * num_class, ) labels = torch.arange(self.num_classes, device=mask_cls.device). \ unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) scores_per_image, top_indices = scores.flatten(0, 1).topk( max_per_image, sorted=False) labels_per_image = labels[top_indices] query_indices = top_indices // self.num_classes mask_pred = mask_pred[query_indices] # extract things is_thing = labels_per_image < self.num_things_classes scores_per_image = scores_per_image[is_thing] labels_per_image = labels_per_image[is_thing] mask_pred = mask_pred[is_thing] mask_pred_binary = (mask_pred > 0).float() mask_scores_per_image = (mask_pred.sigmoid() * mask_pred_binary).flatten(1).sum(1) / ( mask_pred_binary.flatten(1).sum(1) + 1e-6) det_scores = scores_per_image * mask_scores_per_image mask_pred_binary = mask_pred_binary.bool() bboxes = mask2bbox(mask_pred_binary) results = InstanceData() results.bboxes = bboxes results.labels = labels_per_image results.scores = det_scores results.masks = mask_pred_binary return results def proposal_postprocess(self, mask_score: Tensor, mask_pred: Tensor) -> InstanceData: max_per_image = self.test_cfg.get('num_proposals', 10) h, w = mask_pred.shape[-2:] # shape (num_queries, num_class) scores = mask_score.sigmoid().squeeze(-1) scores_per_image, top_indices = scores.topk(max_per_image, sorted=True) mask_selected = mask_pred[top_indices] proposals = [] for idx in range(len(mask_selected)): mask = mask_selected[len(mask_selected) - idx - 1] proposals.append(mask.sigmoid() > .5) seg_map = torch.stack(proposals) return seg_map def predict(self, mask_cls_results: Tensor, mask_pred_results: Tensor, batch_data_samples: SampleList, iou_results=None, rescale: bool = False, **kwargs) -> List[dict]: """Test segment without test-time aumengtation. Only the output of last decoder layers was used. Args: mask_cls_results (Tensor): Mask classification logits, shape (batch_size, num_queries, cls_out_channels). Note `cls_out_channels` should includes background. mask_pred_results (Tensor): Mask logits, shape (batch_size, num_queries, h, w). batch_data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. iou_results: None rescale (bool): If True, return boxes in original image space. Default False. Returns: list[dict]: Instance segmentation \ results and panoptic segmentation results for each \ image. .. code-block:: none [ { 'pan_results': PixelData, 'ins_results': InstanceData, # semantic segmentation results are not supported yet 'sem_results': PixelData }, ... ] """ batch_img_metas = [ data_sample.metainfo for data_sample in batch_data_samples ] panoptic_on = self.test_cfg.get('panoptic_on', True) semantic_on = self.test_cfg.get('semantic_on', False) instance_on = self.test_cfg.get('instance_on', False) proposal_on = self.test_cfg.get('proposal_on', False) assert not semantic_on, 'segmantic segmentation ' \ 'results are not supported yet.' results = [] idx = 0 for mask_cls_result, mask_pred_result, meta in zip( mask_cls_results, mask_pred_results, batch_img_metas): # remove padding img_height, img_width = meta['img_shape'][:2] mask_pred_result = mask_pred_result.to(mask_cls_results.device) mask_pred_result = mask_pred_result[:, :img_height, :img_width] if rescale: # return result in original resolution ori_height, ori_width = meta['ori_shape'][:2] mask_pred_result = F.interpolate( mask_pred_result[:, None], size=(ori_height, ori_width), mode='bilinear', align_corners=False)[:, 0] result = dict() if panoptic_on: pan_results = self.panoptic_postprocess( mask_cls_result, mask_pred_result ) result['pan_results'] = pan_results if instance_on: ins_results = self.instance_postprocess( mask_cls_result, mask_pred_result ) result['ins_results'] = ins_results if semantic_on: sem_results = self.semantic_postprocess( mask_cls_result, mask_pred_result ) result['sem_results'] = sem_results if proposal_on: pro_results = self.proposal_postprocess( iou_results[idx], mask_pred_result ) result['pro_results'] = pro_results results.append(result) idx += 1 return results