import torch


def preprocess_video_panoptic_gt(
        gt_labels,
        gt_masks,
        gt_semantic_seg,
        gt_instance_ids,
        num_things,
        num_stuff,
):
    """Convert per-frame video panoptic annotations into per-target tensors.

    Args:
        gt_labels (Tensor): Shape (N, 2); column 1 holds the class label of
            the annotation in the corresponding row of ``gt_instance_ids``.
        gt_masks (list): Per-frame instance masks, one mmdet-style
            ``BitmapMasks`` object per frame.
        gt_semantic_seg (Tensor | None): Semantic segmentation maps of shape
            (num_frames, 1, H, W), or None if stuff annotations are absent.
        gt_instance_ids (Tensor): Shape (N, 2); each row is
            ``(frame_id, instance_id)``, aligned row-wise with ``gt_labels``.
        num_things (int): Number of thing classes; thing labels occupy
            ``[0, num_things)``.
        num_stuff (int): Number of stuff classes; stuff labels occupy
            ``[num_things, num_things + num_stuff)``.

    Returns:
        tuple: ``(labels, masks)``, where ``labels`` has shape
        (num_targets,) and ``masks`` has shape
        (num_targets, num_frames, H, W), both of dtype ``torch.long``.
    """
    num_classes = num_things + num_stuff
    num_frames = len(gt_masks)
    mask_size = gt_masks[0].masks.shape[-2:]

    # Pad each frame's instance masks to a common size and convert them to
    # boolean tensors; entry ``i`` holds the masks of frame ``i``.
    thing_masks_list = []
    for frame_id in range(num_frames):
        thing_masks_list.append(gt_masks[frame_id].pad(
            mask_size, pad_val=0).to_tensor(
                dtype=torch.bool, device=gt_labels.device))
    # Group the per-frame masks by instance id so that each thing instance
    # gets one (num_frames, H, W) mask tube; frames where the instance is
    # absent are filled with empty masks.
    instances = torch.unique(gt_instance_ids[:, 1])
    things_masks = []
    labels = []
    for instance in instances:
        # Row indices of ``gt_instance_ids`` that belong to this instance.
        pos_ins = torch.nonzero(
            torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0]
        labels_instance = gt_labels[:, 1][pos_ins]
        # An instance must keep the same class label across all frames.
        assert torch.allclose(labels_instance, labels_instance[0])
        labels.append(labels_instance[0])
        instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(
            dtype=torch.int32).tolist()
        instance_masks = []
        for frame_id in range(num_frames):
            frame_instance_ids = gt_instance_ids[
                gt_instance_ids[:, 0] == frame_id, 1]
            if frame_id not in instance_frame_ids:
                # The instance does not appear in this frame.
                empty_mask = torch.zeros(
                    mask_size,
                    dtype=thing_masks_list[frame_id].dtype,
                    device=thing_masks_list[frame_id].device)
                instance_masks.append(empty_mask)
            else:
                # Index of this instance within the frame's annotations.
                pos_inner_frame = torch.nonzero(
                    torch.eq(frame_instance_ids, instance),
                    as_tuple=True)[0].item()
                instance_masks.append(
                    thing_masks_list[frame_id][pos_inner_frame])
        things_masks.append(torch.stack(instance_masks))

    if len(instances) == 0:
        # No thing instances in the clip: produce empty (0, T, H, W) masks
        # and an empty label tensor.
        things_masks = torch.stack(thing_masks_list, dim=1)
        labels = torch.empty_like(instances)
    else:
        things_masks = torch.stack(things_masks)
        labels = torch.stack(labels)
    # Thing labels must lie in [0, num_things).
    assert torch.all(torch.less(labels, num_things))

    if gt_semantic_seg is not None:
        things_labels = labels
        # (num_frames, 1, H, W) -> (num_frames, H, W)
        gt_semantic_seg = gt_semantic_seg.squeeze(1)

        # Collect one mask tube per stuff class present in the clip; values
        # outside [num_things, num_classes) are thing labels or ignore
        # values and are skipped.
        semantic_labels = torch.unique(gt_semantic_seg, sorted=False)
        stuff_masks_list = []
        stuff_labels_list = []
        for label in semantic_labels:
            if label < num_things or label >= num_classes:
                continue
            stuff_masks_list.append(gt_semantic_seg == label)
            stuff_labels_list.append(label)

        if len(stuff_masks_list) > 0:
            stuff_masks = torch.stack(stuff_masks_list, dim=0)
            stuff_labels = torch.stack(stuff_labels_list, dim=0)
            # Stuff labels must lie in [num_things, num_classes).
            assert torch.all(torch.ge(stuff_labels, num_things)) and \
                torch.all(torch.less(stuff_labels, num_classes))
            labels = torch.cat([things_labels, stuff_labels], dim=0)
            masks = torch.cat([things_masks, stuff_masks], dim=0)
        else:
            labels = things_labels
            masks = things_masks
        assert len(labels) == len(masks)
    else:
        masks = things_masks

    labels = labels.to(dtype=torch.long)
    masks = masks.to(dtype=torch.long)
    return labels, masks
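

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): exercises the function above
# on a tiny two-frame clip. ``_DummyBitmapMasks`` is a hypothetical stand-in
# for mmdet's ``BitmapMasks`` that implements only the attributes and methods
# the function actually uses; all names and sizes below are made up.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _DummyBitmapMasks:
        """Minimal stand-in exposing ``masks``, ``pad`` and ``to_tensor``."""

        def __init__(self, masks):
            self.masks = masks  # (num_instances, H, W) bool tensor

        def pad(self, mask_size, pad_val=0):
            # Masks already match ``mask_size`` in this sketch, so padding
            # is a no-op.
            return self

        def to_tensor(self, dtype, device):
            return self.masks.to(dtype=dtype, device=device)

    num_things, num_stuff = 2, 2  # hypothetical class split
    num_frames, height, width = 2, 4, 4

    # One thing instance (id 7, class 1) visible in both frames.
    gt_instance_ids = torch.tensor([[0, 7], [1, 7]])  # (frame_id, instance)
    gt_labels = torch.tensor([[0, 1], [1, 1]])        # (frame_id, class)
    frame_mask = torch.zeros(1, height, width, dtype=torch.bool)
    frame_mask[0, :2, :2] = True
    gt_masks = [_DummyBitmapMasks(frame_mask) for _ in range(num_frames)]

    # Semantic map filled with stuff class 2, shape (num_frames, 1, H, W).
    gt_semantic_seg = torch.full(
        (num_frames, 1, height, width), 2, dtype=torch.long)

    labels, masks = preprocess_video_panoptic_gt(
        gt_labels, gt_masks, gt_semantic_seg, gt_instance_ids,
        num_things, num_stuff)
    # One thing target plus one stuff target:
    print(labels.shape, masks.shape)  # torch.Size([2]) torch.Size([2, 2, 4, 4])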