# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from mmcv.cnn import Conv2d
from mmcv.ops import point_sample
from mmdet.models import (Mask2FormerTransformerDecoder, inverse_sigmoid,
                          coordinate_to_encoding)
from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
from mmengine import print_log
from mmengine.dist import get_dist_info
from mmengine.model import caffe2_xavier_init, ModuleList
from mmengine.structures import InstanceData, PixelData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList, TrackDataSample
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptMultiConfig, reduce_mean)
from mmdet.models.layers import SinePositionalEncoding3D
from mmdet.models.utils import (multi_apply, preprocess_panoptic_gt,
                                get_uncertain_point_coords_with_randomness)
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead

from seg.models.utils import preprocess_video_panoptic_gt, mask_pool
from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class Mask2FormerVideoHead(AnchorFreeHead):
    """Implements the Mask2Former head for images and video clips.

    See `Masked-attention Mask Transformer for Universal Image Segmentation
    <https://arxiv.org/abs/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things. Defaults to 80.
        num_stuff_classes (int): Number of stuff. Defaults to 53.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
            decoder.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder position encoding.
            Defaults to dict(num_feats=128, normalize=True).
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the Mask2Former head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the Mask2Former head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to None.
""" def __init__(self, in_channels: List[int], feat_channels: int, out_channels: int, num_things_classes: int = 80, num_stuff_classes: int = 53, num_queries: int = 100, num_transformer_feat_level: int = 3, pixel_decoder: ConfigType = ..., enforce_decoder_input_project: bool = False, transformer_decoder: ConfigType = ..., positional_encoding: ConfigType = None, loss_cls: ConfigType = None, loss_mask: ConfigType = None, loss_dice: ConfigType = None, loss_iou: ConfigType = None, train_cfg: OptConfigType = None, test_cfg: OptConfigType = None, init_cfg: OptMultiConfig = None, # ov configs sphere_cls: bool = False, ov_classifier_name: Optional[str] = None, logit: Optional[int] = None, # box sup matching_whole_map: bool = False, # box query enable_box_query: bool = False, group_assigner: OptConfigType = None, **kwargs) -> None: super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) self.num_things_classes = num_things_classes self.num_stuff_classes = num_stuff_classes self.num_classes = self.num_things_classes + self.num_stuff_classes self.num_queries = num_queries self.num_transformer_feat_level = num_transformer_feat_level self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads self.num_transformer_decoder_layers = transformer_decoder.num_layers assert pixel_decoder.encoder.layer_cfg. \ self_attn_cfg.num_levels == num_transformer_feat_level pixel_decoder_ = copy.deepcopy(pixel_decoder) pixel_decoder_.update( in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) self.pixel_decoder = MODELS.build(pixel_decoder_) self.transformer_decoder = Mask2FormerTransformerDecoder( **transformer_decoder) self.decoder_embed_dims = self.transformer_decoder.embed_dims self.decoder_input_projs = ModuleList() # from low resolution to high resolution for _ in range(num_transformer_feat_level): if (self.decoder_embed_dims != feat_channels or enforce_decoder_input_project): self.decoder_input_projs.append( Conv2d( feat_channels, self.decoder_embed_dims, kernel_size=1)) else: self.decoder_input_projs.append(nn.Identity()) self.decoder_positional_encoding = SinePositionalEncoding3D( **positional_encoding) self.query_embed = nn.Embedding(self.num_queries, feat_channels) self.query_feat = nn.Embedding(self.num_queries, feat_channels) # from low resolution to high resolution self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) if not sphere_cls: self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) self.mask_embed = nn.Sequential( nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), nn.Linear(feat_channels, out_channels)) if loss_iou is not None: self.iou_embed = nn.Sequential( nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), nn.Linear(feat_channels, 1)) else: self.iou_embed = None self.test_cfg = test_cfg self.train_cfg = train_cfg if train_cfg: self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) self.sampler = TASK_UTILS.build( self.train_cfg['sampler'], default_args=dict(context=self)) self.num_points = self.train_cfg.get('num_points', 12544) self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) self.importance_sample_ratio = self.train_cfg.get( 'importance_sample_ratio', 0.75) self.class_weight = loss_cls.class_weight self.loss_cls = MODELS.build(loss_cls) self.loss_mask = MODELS.build(loss_mask) self.loss_dice = MODELS.build(loss_dice) if loss_iou is not None: 
            self.loss_iou = MODELS.build(loss_iou)
        else:
            self.loss_iou = None

        # prepare OV things
        # OV cls embed
        if sphere_cls:
            rank, world_size = get_dist_info()
            if ov_classifier_name is None:
                _dim = 1024  # temporarily hard-coded
                cls_embed = torch.empty(self.num_classes, _dim)
                torch.nn.init.orthogonal_(cls_embed)
                cls_embed = cls_embed[:, None]
            else:
                # ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth")
                ov_path = os.path.join(os.path.expanduser('./models/'), f"{ov_classifier_name}.pth")
                cls_embed = torch.load(ov_path)
                cls_embed_norm = cls_embed.norm(p=2, dim=-1)
                assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))
            if self.loss_cls and self.loss_cls.use_sigmoid:
                pass
            else:
                # append a zero-vector background embedding after the class
                # embeddings
                _dim = cls_embed.size(2)
                _prototypes = cls_embed.size(1)

                # if rank == 0:
                #     back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
                #     # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
                # else:
                #     back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
                # if world_size > 1:
                #     dist.broadcast(back_token, src=0)
                # back_token = back_token.to(device='cpu')
                back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
                cls_embed = torch.cat([
                    cls_embed, back_token.repeat(_prototypes, 1)[None]
                ], dim=0)
            self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)

            # cls embd proj
            cls_embed_dim = self.cls_embed.size(0)
            self.cls_proj = nn.Sequential(
                nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
                nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
                nn.Linear(feat_channels, cls_embed_dim)
            )

            # Haobo Yuan:
            # For the logit_scale, I refer to this issue:
            # https://github.com/openai/CLIP/issues/46#issuecomment-945062212
            # https://github.com/openai/CLIP/issues/46#issuecomment-782558799
            # Based on my understanding, it is a mistake in CLIP, since they
            # mention that they follow the InstDisc (Wu, 2018) paper, and
            # InstDisc sets a non-learnable temperature of np.log(1 / 0.07).
            # 4.6052 is np.log(1 / 0.01); a learnable np.log(1 / 0.07) quickly
            # converges to np.log(1 / 0.01).
            if logit is None:
                logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            else:
                logit_scale = torch.tensor(logit, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)

            # Mask Pooling
            self.mask_pooling = mask_pool
            self.mask_pooling_proj = nn.Sequential(
                nn.LayerNorm(feat_channels),
                nn.Linear(feat_channels, feat_channels)
            )

        # box inst
        self.matching_whole_map = matching_whole_map

        # enable box query
        self.enable_box_query = enable_box_query
        if self.enable_box_query:
            self.num_mask_tokens = 1
            self.mask_tokens = nn.Embedding(self.num_mask_tokens, feat_channels)
            self.pb_embedding = nn.Embedding(2, feat_channels)
            self.pos_linear = nn.Linear(2 * feat_channels, feat_channels)

    def init_weights(self) -> None:
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = self.init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(
                checkpoint_path, prefix=self.init_cfg['prefix'])
            msg = self.load_state_dict(state_dict, strict=False)
            # msg is (missing_keys, unexpected_keys)
            print_log(f"m: {msg[0]} \n u: {msg[1]}", logger='current')
            return None

        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)
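
    # A minimal sketch (hypothetical paths) of how the 'Pretrained' branch in
    # `init_weights` above is driven from a config:
    #   init_cfg = dict(type='Pretrained',
    #                   checkpoint='path/to/whole_model.pth',
    #                   prefix='panoptic_head.')
    # `load_checkpoint_with_prefix` keeps only the checkpoint keys under
    # `prefix` and strips it, so the remaining keys line up with this head's
    # own `state_dict()` before `load_state_dict(strict=False)`.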

    def preprocess_gt(
            self, batch_gt_instances: InstanceList,
            batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList:
        """Preprocess the ground truth for all images.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, each is
                ground truth labels of each bbox, with shape (num_gts, )
                and ``masks``, each is ground truth masks of each instance
                of an image, shape (num_gts, h, w).
            batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth
                of semantic segmentation, each with the shape (1, h, w).
                [0, num_thing_class - 1] means things,
                [num_thing_class, num_class - 1] means stuff,
                255 means VOID. It is None when training instance
                segmentation.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Ground truth class indices\
                    for an image, with shape (n, ), n is the sum of\
                    the number of stuff types and the number of instances
                    in an image.
                - masks (Tensor): Ground truth mask for an\
                    image, with shape (n, h, w).
        """
        num_things_list = [self.num_things_classes] * len(batch_gt_instances)
        num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances)
        if isinstance(batch_gt_instances[0], List):
            # video input: merge per-frame annotations into per-clip targets
            gt_labels_list = [
                [torch.stack([torch.ones_like(gt_instances['labels']) * frame_id,
                              gt_instances['labels']], dim=1)
                 for frame_id, gt_instances in enumerate(gt_vid_instances)]
                for gt_vid_instances in batch_gt_instances
            ]
            gt_labels_list = [torch.cat(gt_labels, dim=0) for gt_labels in gt_labels_list]
            gt_masks_list = [
                [gt_instances['masks'] for gt_instances in gt_vid_instances]
                for gt_vid_instances in batch_gt_instances
            ]
            gt_semantic_segs = [
                [None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
                 for gt_semantic_seg in gt_vid_semantic_segs]
                for gt_vid_semantic_segs in batch_gt_semantic_segs
            ]
            if gt_semantic_segs[0][0] is None:
                gt_semantic_segs = [None] * len(batch_gt_instances)
            else:
                gt_semantic_segs = [torch.stack(gt_sem_seg, dim=0)
                                    for gt_sem_seg in gt_semantic_segs]
            gt_instance_ids_list = [
                [torch.stack([torch.ones_like(gt_instances['instances_ids']) * frame_id,
                              gt_instances['instances_ids']], dim=1)
                 for frame_id, gt_instances in enumerate(gt_vid_instances)]
                for gt_vid_instances in batch_gt_instances
            ]
            gt_instance_ids_list = [torch.cat(gt_instance_ids, dim=0)
                                    for gt_instance_ids in gt_instance_ids_list]
            targets = multi_apply(preprocess_video_panoptic_gt, gt_labels_list,
                                  gt_masks_list, gt_semantic_segs,
                                  gt_instance_ids_list, num_things_list,
                                  num_stuff_list)
        else:
            gt_labels_list = [
                gt_instances['labels'] for gt_instances in batch_gt_instances
            ]
            gt_masks_list = [
                gt_instances['masks'] for gt_instances in batch_gt_instances
            ]
            gt_semantic_segs = [
                None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
                for gt_semantic_seg in batch_gt_semantic_segs
            ]
            targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                                  gt_masks_list, gt_semantic_segs,
                                  num_things_list, num_stuff_list)
        labels, masks = targets
        batch_gt_instances = [
            InstanceData(labels=label, masks=mask)
            for label, mask in zip(labels, masks)
        ]
        return batch_gt_instances
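
    # Layout sketch for the video branch above: each entry handed to
    # `preprocess_video_panoptic_gt` pairs a frame index with a label/id, e.g.
    #   gt_labels_list[b]: (num_gts_over_all_frames, 2) rows of (frame_id, label)
    # and the merged targets stack each instance's masks over time, so `masks`
    # in the returned InstanceData is (num_instances, T, h, w); this is the
    # layout `loss()` later flattens with `.flatten(1, 2)`.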

    def get_queries(self, batch_data_samples):
        # NOTE: this prompt-encoding path expects `self.pe` (a prompt encoder)
        # and `self.query_proj` to be provided by the surrounding model; they
        # are not created in this head's `__init__`.
        img_size = batch_data_samples[0].batch_input_shape
        query_feat_list = []
        bp_list = []
        for idx, data_sample in enumerate(batch_data_samples):
            # bp flags the prompt type: 0 for boxes, 1 for points
            is_box = data_sample.gt_instances.bp.eq(0)
            is_point = data_sample.gt_instances.bp.eq(1)
            assert is_box.any()
            sparse_embed, _ = self.pe(
                data_sample.gt_instances[is_box],
                image_size=img_size,
                with_bboxes=True,
                with_points=False,
            )
            sparse_embed = [sparse_embed]

            if is_point.any():
                _sparse_embed, _ = self.pe(
                    data_sample.gt_instances[is_point],
                    image_size=img_size,
                    with_bboxes=False,
                    with_points=True,
                )
                sparse_embed.append(_sparse_embed)
            sparse_embed = torch.cat(sparse_embed)
            assert len(sparse_embed) == len(data_sample.gt_instances)
            query_feat_list.append(self.query_proj(sparse_embed.flatten(1, 2)))
            bp_list.append(data_sample.gt_instances.bp)

        query_feat = torch.stack(query_feat_list)
        bp_labels = torch.stack(bp_list).to(dtype=torch.long)
        bp_embed = self.pb_embedding.weight[bp_labels]
        bp_embed = bp_embed.repeat_interleave(self.num_mask_tokens, dim=1)

        query_feat = query_feat + bp_embed
        return query_feat, None

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights\
                    of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - avg_factor (int): Average factor that is used to average\
                    the loss. When using a sampling method, avg_factor is
                    usually the sum of positive and negative priors. When
                    using `MaskPseudoSampler`, `avg_factor` is usually equal
                    to the number of positive priors.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having an HxW
                dimension). The results will be concatenated at the end.
        """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        labels_list, label_weights_list, mask_targets_list, mask_weights_list, \
            pos_inds_list, neg_inds_list, sampling_results_list = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)
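
    # `_get_targets_single` below follows the Mask2Former recipe of matching
    # on a shared set of randomly sampled points rather than on full masks:
    #   point_coords ~ U(0, 1)^2, shape (1, num_points, 2)
    #   mask_points_pred: (num_queries, num_points)
    #   gt_points_masks:  (num_gts, num_points)
    # so the assignment cost is O(num_points) per query-gt pair instead of
    # O(h * w), while the final mask targets still use the full-resolution gt.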

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        if not self.matching_whole_map:
            point_coords = torch.rand((1, self.num_points, 2),
                                      device=cls_score.device)
            # shape (num_queries, num_points)
            mask_points_pred = point_sample(
                mask_pred.unsqueeze(1),
                point_coords.repeat(num_queries, 1, 1)).squeeze(1)
            # shape (num_gts, num_points)
            gt_points_masks = point_sample(
                gt_masks.unsqueeze(1).float(),
                point_coords.repeat(num_gts, 1, 1)).squeeze(1)
        else:
            mask_points_pred = mask_pred
            gt_points_masks = gt_masks

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances,
            img_meta=img_meta)
        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)
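
    # Per-image target layout produced above (sketch):
    #   labels:       (num_queries,)   matched queries carry their gt class,
    #                                  unmatched ones carry `self.num_classes`
    #                                  (the background index)
    #   mask_targets: (num_pos, h, w)  only positives carry mask targets
    #   mask_weights: (num_queries,)   1.0 for positives, 0.0 elsewhere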
""" num_dec_layers = len(all_cls_scores) batch_gt_instances_list = [ batch_gt_instances for _ in range(num_dec_layers) ] img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] losses_cls, losses_mask, losses_dice = multi_apply( self._loss_by_feat_single, all_cls_scores, all_mask_preds, batch_gt_instances_list, img_metas_list ) loss_dict = dict() # loss from other decoder layers num_dec_layer = 0 for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i num_dec_layer += 1 # loss from the last decoder layer loss_dict['loss_cls'] = losses_cls[-1] loss_dict['loss_mask'] = losses_mask[-1] loss_dict['loss_dice'] = losses_dice[-1] return loss_dict def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, batch_gt_instances: List[InstanceData], batch_img_metas: List[dict]) -> Tuple[Tensor]: """Loss function for outputs from a single decoder layer. Args: cls_scores (Tensor): Mask score logits from a single decoder layer for all images. Shape (batch_size, num_queries, cls_out_channels). Note `cls_out_channels` should includes background. mask_preds (Tensor): Mask logits for a pixel decoder for all images. Shape (batch_size, num_queries, h, w). batch_gt_instances (list[obj:`InstanceData`]): each contains ``labels`` and ``masks``. batch_img_metas (list[dict]): List of image meta information. Returns: tuple[Tensor]: Loss components for outputs from a single \ decoder layer. """ batch_size, num_ins = cls_scores.size(0), cls_scores.size(1) # hack here: is_sam = num_ins != self.num_queries if not is_sam: cls_scores_list = [cls_scores[i] for i in range(batch_size)] mask_preds_list = [mask_preds[i] for i in range(batch_size)] labels_list, label_weights_list, mask_targets_list, mask_weights_list, avg_factor = \ self.get_targets(cls_scores_list, mask_preds_list, batch_gt_instances, batch_img_metas) labels = torch.stack(labels_list, dim=0) label_weights = torch.stack(label_weights_list, dim=0) mask_targets = torch.cat(mask_targets_list, dim=0) mask_weights = torch.stack(mask_weights_list, dim=0) else: labels = torch.stack([item.labels for item in batch_gt_instances]) label_weights = labels.new_ones((batch_size, num_ins), dtype=torch.float) mask_targets = torch.cat([item.masks for item in batch_gt_instances]) mask_weights = mask_targets.new_ones((batch_size, num_ins), dtype=torch.float) avg_factor = cls_scores.size(1) # classification loss # shape (batch_size * num_queries, ) cls_scores = cls_scores.flatten(0, 1) labels = labels.flatten(0, 1) label_weights = label_weights.flatten(0, 1) class_weight = cls_scores.new_tensor(self.class_weight) ignore_inds = labels.eq(-1.) # zero will not be involved in the loss cal labels[ignore_inds] = 0 label_weights[ignore_inds] = 0. 
        obj_inds = labels.eq(self.num_classes)
        if is_sam:
            cls_avg_factor = cls_scores.new_tensor([0])
        else:
            cls_avg_factor = class_weight[labels].sum()
        cls_avg_factor = reduce_mean(cls_avg_factor)
        cls_avg_factor = max(cls_avg_factor, 1)
        if self.loss_iou is not None:
            # the last channel is an objectness/IoU logit; classify with the
            # remaining channels and supervise the last one separately
            loss_cls = self.loss_cls(
                cls_scores[..., :-1],
                labels,
                label_weights,
                avg_factor=cls_avg_factor)

            loss_iou = self.loss_iou(
                cls_scores[..., -1:],
                obj_inds.to(dtype=torch.long),
                avg_factor=cls_avg_factor)
            if is_sam:
                loss_iou = loss_iou * 0
            loss_cls = loss_cls + loss_iou
        else:
            loss_cls = self.loss_cls(
                cls_scores,
                labels,
                label_weights,
                avg_factor=cls_avg_factor)

        # loss_mask
        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        if not self.matching_whole_map:
            with torch.no_grad():
                points_coords = get_uncertain_point_coords_with_randomness(
                    mask_preds.unsqueeze(1), None, self.num_points,
                    self.oversample_ratio, self.importance_sample_ratio)
                # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
                mask_point_targets = point_sample(
                    mask_targets.unsqueeze(1).float(),
                    points_coords).squeeze(1)
            # shape (num_queries, h, w) -> (num_queries, num_points)
            mask_point_preds = point_sample(
                mask_preds.unsqueeze(1), points_coords).squeeze(1)
        else:
            mask_point_targets = mask_targets
            mask_point_preds = mask_preds

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_queries, num_points) -> (num_queries * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice

    def forward_logit(self, cls_embd):
        cls_pred = torch.einsum('bnc,ckp->bnkp',
                                F.normalize(cls_embd, dim=-1), self.cls_embed)
        cls_pred = cls_pred.max(-1).values
        cls_pred = self.logit_scale.exp() * cls_pred
        return cls_pred
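
    # Shape sketch for `forward_logit` above (names are illustrative):
    #   cls_embd:       (b, n, c)  L2-normalized query embeddings
    #   self.cls_embed: (c, k, p)  k classes, p text prototypes per class
    #   einsum 'bnc,ckp->bnkp' -> cosine logits (b, n, k, p)
    #   .max(-1) keeps the best prototype per class -> (b, n, k)
    #   exp(logit_scale) (~100 for logit_scale = 4.6052) sharpens the logits.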
""" decoder_out = self.transformer_decoder.post_norm(decoder_out) # shape (num_queries, batch_size, c) if isinstance(self.cls_embed, nn.Module): cls_pred = self.cls_embed(decoder_out) # shape (num_queries, batch_size, c) mask_embed = self.mask_embed(decoder_out) # shape (num_queries, batch_size, h, w) mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) if not isinstance(self.cls_embed, nn.Module): maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach()) maskpool_embd = self.mask_pooling_proj(maskpool_embd) cls_embd = self.cls_proj(maskpool_embd + decoder_out) cls_pred = self.forward_logit(cls_embd) if self.iou_embed is not None: iou_pred = self.iou_embed(decoder_out) cls_pred = torch.cat([cls_pred, iou_pred], dim=-1) if num_frames > 0: assert len(mask_pred.shape) == 4 assert mask_pred.shape[2] % num_frames == 0 frame_h = mask_pred.shape[2] // num_frames num_q = mask_pred.shape[1] _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) attn_mask = F.interpolate( _mask_pred, attn_mask_target_size, mode='bilinear', align_corners=False) attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) else: attn_mask = F.interpolate( mask_pred, attn_mask_target_size, mode='bilinear', align_corners=False) # shape (num_queries, batch_size, h, w) -> # (batch_size * num_head, num_queries, h, w) attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( (1, self.num_heads, 1, 1)).flatten(0, 1) attn_mask = attn_mask.sigmoid() < 0.5 attn_mask = attn_mask.detach() return cls_pred, mask_pred, attn_mask def forward(self, x: List[Tensor], batch_data_samples: SampleList) -> Tuple[List[Tensor]]: """Forward function. Args: x (list[Tensor]): Multi scale Features from the upstream network, each is a 4D-tensor. batch_data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. Returns: tuple[list[Tensor]]: A tuple contains two elements. - cls_pred_list (list[Tensor)]: Classification logits \ for each decoder layer. Each is a 3D-tensor with shape \ (batch_size, num_queries, cls_out_channels). \ Note `cls_out_channels` should includes background. - mask_pred_list (list[Tensor]): Mask logits for each \ decoder layer. Each with shape (batch_size, num_queries, \ h, w). 
""" batch_img_metas = [] if isinstance(batch_data_samples[0], TrackDataSample): for track_sample in batch_data_samples: cur_list = [] for det_sample in track_sample: cur_list.append(det_sample.metainfo) batch_img_metas.append(cur_list) num_frames = len(batch_img_metas[0]) else: for data_sample in batch_data_samples: batch_img_metas.append(data_sample.metainfo) num_frames = 0 batch_size = len(batch_img_metas) mask_features, multi_scale_memorys = self.pixel_decoder(x) if num_frames > 0: mask_features = mask_features.unflatten(0, (batch_size, num_frames)) mask_features = mask_features.transpose(1, 2).flatten(2, 3) decoder_inputs = [] decoder_positional_encodings = [] for i in range(self.num_transformer_feat_level): decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) # shape (batch_size, c, h, w) -> (batch_size, h*w, c) decoder_input = decoder_input.flatten(2).permute(0, 2, 1) if num_frames > 0: decoder_input = decoder_input.unflatten(0, (batch_size, num_frames)) decoder_input = decoder_input.flatten(1, 2) level_embed = self.level_embed.weight[i].view(1, 1, -1) decoder_input = decoder_input + level_embed # shape (batch_size, c, h, w) -> (batch_size, h*w, c) num_frames_real = 1 if num_frames == 0 else num_frames mask = decoder_input.new_zeros( (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) decoder_positional_encoding = self.decoder_positional_encoding( mask) decoder_positional_encoding = decoder_positional_encoding.transpose( 1, 2).flatten(2).permute(0, 2, 1) decoder_inputs.append(decoder_input) decoder_positional_encodings.append(decoder_positional_encoding) if self.enable_box_query and batch_data_samples[0].data_tag in ['sam_mul', 'sam']: query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo(batch_data_samples) query_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) query_embed = self.pos_linear(query_embed) else: # coco style query generation # shape (num_queries, c) -> (batch_size, num_queries, c) query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1)) query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1)) self_attn_mask = None cls_pred_list = [] mask_pred_list = [] cls_pred, mask_pred, attn_mask = self._forward_head( query_feat, mask_features, multi_scale_memorys[0].shape[-2:], num_frames=num_frames ) cls_pred_list.append(cls_pred) if num_frames > 0: mask_pred = mask_pred.unflatten(2, (num_frames, -1)) mask_pred_list.append(mask_pred) for i in range(self.num_transformer_decoder_layers): level_idx = i % self.num_transformer_feat_level # if a mask is all True(all background), then set it all False. 
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                cross_attn_mask=attn_mask,
                self_attn_mask=self_attn_mask,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self._forward_head(
                query_feat, mask_features,
                multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:],
                num_frames=num_frames)

            cls_pred_list.append(cls_pred)
            if num_frames > 0:
                mask_pred = mask_pred.unflatten(2, (num_frames, -1))
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list, query_feat

    def loss(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            if isinstance(data_sample, TrackDataSample):
                clip_meta = []
                clip_instances = []
                clip_sem_seg = []
                for det_sample in data_sample:
                    clip_meta.append(det_sample.metainfo)
                    clip_instances.append(det_sample.gt_instances)
                    if 'gt_sem_seg' in det_sample:
                        clip_sem_seg.append(det_sample.gt_sem_seg)
                    else:
                        clip_sem_seg.append(None)
                batch_img_metas.append(clip_meta)
                batch_gt_instances.append(clip_instances)
                batch_gt_semantic_segs.append(clip_sem_seg)
            else:
                batch_img_metas.append(data_sample.metainfo)
                batch_gt_instances.append(data_sample.gt_instances)
                if 'gt_sem_seg' in data_sample:
                    batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
                else:
                    batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds, _ = self(x, batch_data_samples)

        # preprocess ground truth
        if not self.enable_box_query or batch_data_samples[0].data_tag in ['coco', 'sam']:
            batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                    batch_gt_semantic_segs)
        # loss
        if isinstance(batch_data_samples[0], TrackDataSample):
            num_frames = len(batch_img_metas[0])
            # flatten the temporal axis into height so the image loss applies
            all_mask_preds = [mask.flatten(2, 3) for mask in all_mask_preds]
            for instance in batch_gt_instances:
                instance['masks'] = instance['masks'].flatten(1, 2)
            film_metas = [
                {
                    'img_shape': (meta[0]['img_shape'][0] * num_frames,
                                  meta[0]['img_shape'][1])
                } for meta in batch_img_metas
            ]
            batch_img_metas = film_metas

        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        if self.enable_box_query:
            losses['loss_zero'] = 0 * self.query_feat.weight.sum() + \
                0 * self.query_embed.weight.sum()
            losses['loss_zero'] += 0 * self.pb_embedding.weight.sum()
            losses['loss_zero'] += 0 * self.mask_tokens.weight.sum()
            for name, param in self.pos_linear.named_parameters():
                losses['loss_zero'] += 0 * param.sum()
        return losses
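
    # `loss_zero` above multiplies every prompt-related parameter by zero so
    # they still appear in the autograd graph on batches without box/point
    # prompts; otherwise DistributedDataParallel would complain about
    # parameters that received no gradient.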

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                return_query=False,
                ) -> Tuple[Tensor, ...]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            return_query (bool): Whether to additionally return the final
                query features. Defaults to False.

        Returns:
            tuple[Tensor]: A tuple contains two tensors.

                - mask_cls_results (Tensor): Mask classification logits,\
                    shape (batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - mask_pred_results (Tensor): Mask logits, shape \
                    (batch_size, num_queries, h, w).
        """
        data_sample = batch_data_samples[0]
        if isinstance(data_sample, TrackDataSample):
            img_shape = data_sample[0].metainfo['batch_input_shape']
            num_frames = len(data_sample)
        else:
            img_shape = data_sample.metainfo['batch_input_shape']
            num_frames = 0
        all_cls_scores, all_mask_preds, query_feat = self(x, batch_data_samples)

        if self.iou_embed is not None:
            # split the extra IoU/objectness logit off the class scores
            _all_cls_scores = [cls_score[..., :-1] for cls_score in all_cls_scores]
            iou_results = [cls_score[..., -1:] for cls_score in all_cls_scores]
            all_cls_scores = _all_cls_scores
        else:
            iou_results = None

        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]
        if iou_results is not None:
            iou_results = iou_results[-1]

        if num_frames > 0:
            mask_pred_results = mask_pred_results.flatten(1, 2)
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)
        if num_frames > 0:
            num_queries = mask_cls_results.shape[1]
            mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames))

        if iou_results is None:
            return mask_cls_results, mask_pred_results

        if return_query:
            return mask_cls_results, mask_pred_results, query_feat, iou_results
        else:
            return mask_cls_results, mask_pred_results, iou_results

    def prepare_for_dn_mo(self, batch_data_samples):
        scalar, noise_scale = 100, 0.4
        gt_instances = [t.gt_instances for t in batch_data_samples]

        point_coords = torch.stack([inst.point_coords for inst in gt_instances])
        # `bp` flags the prompt type: 0 for boxes, 1 for points
        pb_labels = torch.stack([inst['bp'] for inst in gt_instances])
        labels = torch.zeros_like(pb_labels).long()

        boxes = point_coords  # + boxes

        factors = []
        for i, data_sample in enumerate(batch_data_samples):
            h, w = data_sample.metainfo['img_shape']
            factor = boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1)
            factors.append(factor)
        factors = torch.stack(factors, 0)

        # normalize to [0, 1] and convert to (cx, cy, w, h)
        # TODO: confirm whether the raw coords are xyxy or xywh here
        boxes = bbox_xyxy_to_cxcywh(boxes / factors)
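        # Worked example: a box (x1, y1, x2, y2) = (100, 50, 300, 250) in a
        # 640x480 image is divided by (w, h, w, h) = (640, 480, 640, 480) and
        # converted, giving (cx, cy, w, h) ~= (0.3125, 0.3125, 0.3125, 0.4167).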
        # box_start = [t['box_start'] for t in targets]
        box_start = [len(point) for point in point_coords]

        known_labels = labels
        known_pb_labels = pb_labels
        known_bboxs = boxes

        known_labels_expanded = known_labels.clone()
        known_pb_labels_expanded = known_pb_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        if noise_scale > 0 and self.training:
            diff = torch.zeros_like(known_bbox_expand)
            diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2
            diff[:, :, 2:] = known_bbox_expand[:, :, 2:]
            # add only very small noise to the input points; there are no
            # boxes here
            sc = 0.01
            for i, st in enumerate(box_start):
                diff[i, :st] = diff[i, :st] * sc
            known_bbox_expand += torch.mul(
                (torch.rand_like(known_bbox_expand) * 2 - 1.0), diff) * noise_scale
            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)

        input_label_embed = self.pb_embedding(known_pb_labels_expanded)
        input_bbox_embed = inverse_sigmoid(known_bbox_expand)

        input_label_embed = input_label_embed.repeat_interleave(
            self.num_mask_tokens,
            1) + self.mask_tokens.weight.unsqueeze(0).repeat(
                input_label_embed.shape[0], input_label_embed.shape[1], 1)
        input_bbox_embed = input_bbox_embed.repeat_interleave(
            self.num_mask_tokens, 1)

        single_pad = self.num_mask_tokens

        # NOTE: `scalar` is recomputed as the number of clicks so that clicks
        # cannot attend to each other
        scalar = int(input_label_embed.shape[1] / self.num_mask_tokens)

        pad_size = input_label_embed.shape[1]
        if input_label_embed.shape[1] > 0:
            input_query_label = input_label_embed
            input_query_bbox = input_bbox_embed

        tgt_size = pad_size
        # use the embedding's device rather than hard-coding 'cuda'
        attn_mask = torch.ones(
            tgt_size, tgt_size, device=input_label_embed.device) < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        for i in range(scalar):
            if i == 0:
                attn_mask[single_pad * i:single_pad * (i + 1),
                          single_pad * (i + 1):pad_size] = True
            if i == scalar - 1:
                attn_mask[single_pad * i:single_pad * (i + 1),
                          :single_pad * i] = True
            else:
                attn_mask[single_pad * i:single_pad * (i + 1),
                          single_pad * (i + 1):pad_size] = True
                attn_mask[single_pad * i:single_pad * (i + 1),
                          :single_pad * i] = True
        mask_dict = {
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'pad_size': pad_size,
            'scalar': scalar,
        }
        return input_query_label, input_query_bbox, attn_mask, mask_dict
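
    # Illustration of the self-attention mask built above for 3 clicks with
    # num_mask_tokens == 1 (False = may attend, True = masked):
    #   [[False,  True,  True],
    #    [ True, False,  True],
    #    [ True,  True, False]]
    # i.e. the mask is block-diagonal, so each click's tokens only attend to
    # themselves and never to tokens from other clicks.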