# OMG_Seg/seg/models/fusion_head/omgseg_fusionhead.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData, PixelData
from torch import Tensor
from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.mask import mask2bbox
from mmdet.utils import OptConfigType, OptMultiConfig
from mmdet.models.seg_heads.panoptic_fusion_heads.base_panoptic_fusion_head import BasePanopticFusionHead


@MODELS.register_module()
class OMGFusionHead(BasePanopticFusionHead):
    """Fusion head of OMG-Seg that post-processes per-query mask
    classification and mask prediction outputs into panoptic, instance
    and proposal results at test time."""

def __init__(
self,
num_things_classes: int = 80,
num_stuff_classes: int = 53,
test_cfg: OptConfigType = None,
loss_panoptic: OptConfigType = None,
init_cfg: OptMultiConfig = None,
**kwargs
):
super().__init__(
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
test_cfg=test_cfg,
loss_panoptic=loss_panoptic,
init_cfg=init_cfg,
**kwargs)

    def loss(self, **kwargs):
        """OMGFusionHead has no training loss."""
        return dict()

    def panoptic_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> PixelData:
        """Panoptic segmentation inference.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`PixelData`: Panoptic segment result of shape \
                (h, w), each element in Tensor means: \
                ``segment_id = _cls + instance_id * INSTANCE_OFFSET``.
        """
object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8)
iou_thr = self.test_cfg.get('iou_thr', 0.8)
filter_low_score = self.test_cfg.get('filter_low_score', False)
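        # per-query class confidence (softmax over classes incl. background)
        # and per-pixel mask probabilities (sigmoid over mask logits)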
scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
mask_pred = mask_pred.sigmoid()
keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
cur_scores = scores[keep]
cur_classes = labels[keep]
cur_masks = mask_pred[keep]
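        # weight each mask by its class score so the per-pixel argmax below
        # assigns every pixel to its most confident query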
cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
h, w = cur_masks.shape[-2:]
panoptic_seg = torch.full((h, w),
self.num_classes,
dtype=torch.int32,
device=cur_masks.device)
if cur_masks.shape[0] == 0:
# We didn't detect any mask :(
pass
else:
cur_mask_ids = cur_prob_masks.argmax(0)
instance_id = 1
for k in range(cur_classes.shape[0]):
pred_class = int(cur_classes[k].item())
isthing = pred_class < self.num_things_classes
mask = cur_mask_ids == k
mask_area = mask.sum().item()
original_area = (cur_masks[k] >= 0.5).sum().item()
if filter_low_score:
mask = mask & (cur_masks[k] >= 0.5)
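                # keep a query only if its visible (argmax) region still
                # covers at least `iou_thr` of its full binary mask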
if mask_area > 0 and original_area > 0:
if mask_area / original_area < iou_thr:
continue
if not isthing:
                        # different stuff regions of the same class are
                        # merged here; all stuff shares instance_id 0
panoptic_seg[mask] = pred_class
else:
panoptic_seg[mask] = (
pred_class + instance_id * INSTANCE_OFFSET)
instance_id += 1
return PixelData(sem_seg=panoptic_seg[None])
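
    # Note: a segment id produced above can be decoded as
    #   label = segment_id % INSTANCE_OFFSET
    #   instance_id = segment_id // INSTANCE_OFFSET
    # mirroring the ``_cls + instance_id * INSTANCE_OFFSET`` encoding.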

    def semantic_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> PixelData:
        """Semantic segmentation postprocess.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`PixelData`: Semantic segment result.
        """
# TODO add semantic segmentation result
raise NotImplementedError

    def instance_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> InstanceData:
        """Instance segmentation postprocess.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`InstanceData`: Instance segmentation results.

            - scores (Tensor): Classification scores, has a shape
                (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
                dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
max_per_image = self.test_cfg.get('max_per_image', 100)
num_queries = mask_cls.shape[0]
# shape (num_queries, num_class)
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
# shape (num_queries * num_class, )
labels = torch.arange(self.num_classes, device=mask_cls.device). \
unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
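        # top-k over all flattened (query, class) pairs; a single query can
        # appear more than once if several of its class scores make the cut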
scores_per_image, top_indices = scores.flatten(0, 1).topk(
max_per_image, sorted=False)
labels_per_image = labels[top_indices]
query_indices = top_indices // self.num_classes
mask_pred = mask_pred[query_indices]
# extract things
is_thing = labels_per_image < self.num_things_classes
scores_per_image = scores_per_image[is_thing]
labels_per_image = labels_per_image[is_thing]
mask_pred = mask_pred[is_thing]
mask_pred_binary = (mask_pred > 0).float()
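        # mask score = mean foreground probability inside the binary mask,
        # used to down-weight low-quality masks in the detection score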
mask_scores_per_image = (mask_pred.sigmoid() *
mask_pred_binary).flatten(1).sum(1) / (
mask_pred_binary.flatten(1).sum(1) + 1e-6)
det_scores = scores_per_image * mask_scores_per_image
mask_pred_binary = mask_pred_binary.bool()
bboxes = mask2bbox(mask_pred_binary)
results = InstanceData()
results.bboxes = bboxes
results.labels = labels_per_image
results.scores = det_scores
results.masks = mask_pred_binary
return results

    def proposal_postprocess(self, mask_score: Tensor,
                             mask_pred: Tensor) -> Tensor:
        """Proposal postprocess: binarize the top-scoring mask proposals.

        Returns:
            Tensor: Stacked boolean masks of shape (num_proposals, h, w),
                ordered from lowest to highest score.
        """
        max_per_image = self.test_cfg.get('num_proposals', 10)
        # shape (num_queries, )
        scores = mask_score.sigmoid().squeeze(-1)
        scores_per_image, top_indices = scores.topk(max_per_image, sorted=True)
        mask_selected = mask_pred[top_indices]
        proposals = []
        # topk returns descending scores; iterate in reverse so the stacked
        # proposals run from lowest to highest score
        for idx in range(len(mask_selected)):
            mask = mask_selected[len(mask_selected) - idx - 1]
            proposals.append(mask.sigmoid() > .5)
        seg_map = torch.stack(proposals)
        return seg_map

def predict(self,
mask_cls_results: Tensor,
mask_pred_results: Tensor,
batch_data_samples: SampleList,
iou_results=None,
rescale: bool = False,
**kwargs) -> List[dict]:
"""Test segment without test-time aumengtation.
Only the output of last decoder layers was used.
Args:
mask_cls_results (Tensor): Mask classification logits,
shape (batch_size, num_queries, cls_out_channels).
Note `cls_out_channels` should includes background.
mask_pred_results (Tensor): Mask logits, shape
(batch_size, num_queries, h, w).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
iou_results: None
rescale (bool): If True, return boxes in
original image space. Default False.
Returns:
list[dict]: Instance segmentation \
results and panoptic segmentation results for each \
image.
.. code-block:: none
[
{
'pan_results': PixelData,
'ins_results': InstanceData,
# semantic segmentation results are not supported yet
'sem_results': PixelData
},
...
]
"""
batch_img_metas = [
data_sample.metainfo for data_sample in batch_data_samples
]
panoptic_on = self.test_cfg.get('panoptic_on', True)
semantic_on = self.test_cfg.get('semantic_on', False)
instance_on = self.test_cfg.get('instance_on', False)
proposal_on = self.test_cfg.get('proposal_on', False)
        assert not semantic_on, 'semantic segmentation ' \
                                'results are not supported yet.'
results = []
        for idx, (mask_cls_result, mask_pred_result, meta) in enumerate(
                zip(mask_cls_results, mask_pred_results, batch_img_metas)):
# remove padding
img_height, img_width = meta['img_shape'][:2]
mask_pred_result = mask_pred_result.to(mask_cls_results.device)
mask_pred_result = mask_pred_result[:, :img_height, :img_width]
if rescale:
# return result in original resolution
ori_height, ori_width = meta['ori_shape'][:2]
mask_pred_result = F.interpolate(
mask_pred_result[:, None],
size=(ori_height, ori_width),
mode='bilinear',
align_corners=False)[:, 0]
result = dict()
if panoptic_on:
pan_results = self.panoptic_postprocess(
mask_cls_result, mask_pred_result
)
result['pan_results'] = pan_results
if instance_on:
ins_results = self.instance_postprocess(
mask_cls_result, mask_pred_result
)
result['ins_results'] = ins_results
if semantic_on:
sem_results = self.semantic_postprocess(
mask_cls_result, mask_pred_result
)
result['sem_results'] = sem_results
if proposal_on:
pro_results = self.proposal_postprocess(
iou_results[idx], mask_pred_result
)
result['pro_results'] = pro_results
results.append(result)
return results
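

# A minimal smoke-test sketch, assuming the head can be instantiated
# directly with a plain-dict `test_cfg` (its keys mirror those read in
# `predict` and `panoptic_postprocess`); the shapes and thresholds below
# are illustrative assumptions, not values from any OMG-Seg config.
if __name__ == '__main__':
    head = OMGFusionHead(
        num_things_classes=80,
        num_stuff_classes=53,
        test_cfg=dict(
            panoptic_on=True,
            # low threshold so random logits still yield some kept queries
            object_mask_thr=0.05,
            iou_thr=0.8))
    num_queries, h, w = 100, 64, 64
    # +1 channel for the background class, as the docstrings require
    mask_cls = torch.randn(num_queries, head.num_classes + 1)
    mask_pred = torch.randn(num_queries, h, w)
    pan = head.panoptic_postprocess(mask_cls, mask_pred)
    print(pan.sem_seg.shape)  # expected: torch.Size([1, 64, 64])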