File size: 3,067 Bytes
4b532c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from torch import nn as nn
from torch.autograd import Function

from ..utils import ext_loader

# Handle to the `_ext` extension, loaded with the `roipoint_pool3d_forward`
# op that RoIPointPool3dFunction.forward invokes.
ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])


class RoIPointPool3d(nn.Module):
    """Encode the geometry-specific features of each 3D proposal.

    Please refer to `Paper of PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_
    for more details.

    Args:
        num_sampled_points (int, optional): Number of samples in each roi.
            Default: 512.
    """

    def __init__(self, num_sampled_points=512):
        super().__init__()
        # Forwarded unchanged to RoIPointPool3dFunction on every call.
        self.num_sampled_points = num_sampled_points

    def forward(self, points, point_features, boxes3d):
        """Pool points and point features inside each 3D box.

        Args:
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).

        Returns:
            tuple[torch.Tensor, torch.Tensor]:

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C).
            - pooled_empty_flag: Empty flag whose shape is (B, M).
        """
        sample_count = self.num_sampled_points
        pooled = RoIPointPool3dFunction.apply(points, point_features, boxes3d,
                                              sample_count)
        return pooled


class RoIPointPool3dFunction(Function):
    """Autograd wrapper around the ``roipoint_pool3d_forward`` extension op.

    Only the forward pass is implemented; requesting gradients raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
        """Pool points and point features inside each 3D box.

        Args:
            ctx: Autograd context (unused, nothing is saved for backward).
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).
            num_sampled_points (int, optional): The num of sampled points.
                Default: 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C), same dtype/device as
              ``point_features``.
            - pooled_empty_flag: Int tensor holding the empty flag, whose
              shape is (B, M).

        Raises:
            AssertionError: If ``points`` is not a 3-dim tensor whose last
                dimension is 3.
        """
        # Kept as an assert so the exception type seen by existing callers
        # does not change.
        assert len(points.shape) == 3 and points.shape[2] == 3
        batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
            1], point_features.shape[2]
        # reshape (unlike view) also accepts non-contiguous boxes; the result
        # is made contiguous before it is handed to the extension anyway.
        pooled_boxes3d = boxes3d.reshape(batch_size, -1, 7)
        # Zero-initialised output buffers handed to the extension op, which is
        # expected to fill them in place.
        pooled_features = point_features.new_zeros(
            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        pooled_empty_flag = point_features.new_zeros(
            (batch_size, boxes_num)).int()

        ext_module.roipoint_pool3d_forward(points.contiguous(),
                                           pooled_boxes3d.contiguous(),
                                           point_features.contiguous(),
                                           pooled_features, pooled_empty_flag)

        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx, grad_out, grad_empty_flag=None):
        """Gradient computation is not supported for this op.

        Autograd passes one gradient per forward output, so the signature
        accepts two gradients (for ``pooled_features`` and
        ``pooled_empty_flag``); otherwise the call would fail with a
        ``TypeError`` about the argument count before reaching the intended
        error below.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError