File size: 6,088 Bytes
81ecb2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import trimesh
import torch.nn as nn
import torch.nn.functional as F

import logging

logger = logging.getLogger(__name__)

class PrimSDF(nn.Module):
    """Collection of cubic volumetric primitives storing per-voxel features.

    Every primitive owns one scalar scale, one 3D position and a dense
    ``prim_shape ** 3`` voxel grid with ``dim_feat`` channels.  ``forward``
    reads the channels as ``[SDF, R, G, B, roughness, metallic]``, and the
    ``geo``/``tex``/``mat`` slice indices are laid out for that 1 + 3 + 2
    channel split, i.e. for the default ``dim_feat=6``.
    """

    def __init__(
        self,
        mesh_obj=None,
        f_sdf=None,
        geo_fn=None,
        asset_list=None,
        num_prims=1024,
        dim_feat=6,
        prim_shape=8,
        init_scale=0.05,
        sdf2alpha_var=0.005,
        auto_scale_init=True,
        init_sampling="uniform",
    ):
        """Allocate primitive parameters and the shared local voxel grid.

        Args:
            mesh_obj: Stored as ``self.mesh_obj``; not used elsewhere in this
                class.
            f_sdf: Stored as ``self.f_sdf``; must be non-``None`` for
                ``_init_param`` to be invoked.
            geo_fn: Forwarded to ``_init_param``.
            asset_list: Forwarded to ``_init_param``.
            num_prims: Number of primitives (rows of the parameters).
            dim_feat: Number of feature channels per voxel.
            prim_shape: Voxel resolution per axis of every primitive.
            init_scale: Forwarded to ``_init_param``.
            sdf2alpha_var: Width used by ``sdf2alpha``.
            auto_scale_init: Stored as ``self.auto_scale_init``.
            init_sampling: Forwarded to ``_init_param`` as ``sampling``.
        """
        super().__init__()
        self.num_prims = num_prims
        # 6 channels features - [SDF, R, G, B, roughness, metallic]
        self.dim_feat = dim_feat
        self.prim_shape = prim_shape
        self.sdf_sampled_point = None
        self.auto_scale_init = auto_scale_init
        self.init_sampling = init_sampling
        self.sdf2alpha_var = sdf2alpha_var

        # assume the mesh is normalized to [-1, 1] cube
        self.mesh_obj = mesh_obj
        self.f_sdf = f_sdf

        # Per-primitive parameters:
        #   srt_param  - [N, 1 (global scale) + 3 (global translation)]
        #   feat_param - [N, D * S^3], stored channel-major (see grid_sample_feat)
        # Both start at zero; scale in particular must be set to a non-zero
        # value before prim_weight / grid_sample_feat can divide by it.
        n_voxels = self.prim_shape ** 3
        self.srt_param = nn.parameter.Parameter(torch.zeros(self.num_prims, 1 + 3))
        self.feat_param = nn.parameter.Parameter(torch.zeros(self.num_prims, self.dim_feat * n_voxels))

        # Column ranges inside feat_param (end indices are non-inclusive):
        # 1 geometry channel, then 3 texture channels, then 2 material channels.
        self.geo_start_index = 0
        self.geo_end_index = self.geo_start_index + n_voxels
        self.tex_start_index = self.geo_end_index
        self.tex_end_index = self.tex_start_index + n_voxels * 3
        self.mat_start_index = self.tex_end_index
        self.mat_end_index = self.mat_start_index + n_voxels * 2

        # local_grid - [prim_shape^3, 3]: voxel centres in local [-1, 1]
        # coordinates, ordered so that row k matches flat voxel k of a
        # primitive's [S, S, S] feature volume with (x, y, z) components as
        # expected by F.grid_sample.
        xx = torch.linspace(-1, 1, self.prim_shape)
        # 'ij' indexing; stacking (z, y, x) is equivalent to 'xy' indexing
        # with the stack order (meshz, meshx, meshy).
        meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='ij')
        self.local_grid = torch.stack((meshz, meshy, meshx), dim=-1).reshape(-1, 3)

        if self.f_sdf is not None and geo_fn is not None and asset_list is not None:
            self._init_param(init_scale=init_scale, geo_fn=geo_fn, asset_list=asset_list, sampling=self.init_sampling)

    @torch.no_grad()
    def _init_param(self, init_scale, geo_fn, asset_list, sampling="uniform"):
        """Initialization hook for the primitive parameters.

        This is a stub: it performs no work and leaves ``srt_param`` and
        ``feat_param`` at their zero initial values.
        """

    def forward(self, x):
        """Evaluate the blended primitive features at query points.

        Args:
            x: Query points, ``[bs, 3]``.

        Returns:
            Dict with ``'sdf'`` (``output[:, 0:1]``, unclipped), ``'tex'``
            (``output[:, 1:4]`` clipped to ``[0, 1]``) and ``'mat'``
            (``output[:, 4:6]`` clipped to ``[0, 1]``).
        """
        blend_weights = self.prim_weight(x)
        output = self.grid_sample_feat(x, blend_weights)
        return {
            'sdf': output[:, 0:1],
            # RGB
            'tex': torch.clip(output[:, 1:4], min=0.0, max=1.0),
            # roughness, metallic
            'mat': torch.clip(output[:, 4:6], min=0.0, max=1.0),
        }

    def grid_sample_feat(self, x, weights):
        """Trilinearly sample every contributing primitive and blend the results.

        Implementation of I_V -> trilinear grid sample of V_i.

        Args:
            x: Query points, ``[bs, 3]``.
            weights: Per-point, per-primitive blend weights, ``[bs, n_prims]``.
                Only entries ``> 0`` contribute.

        Returns:
            ``[bs, dim_feat]`` tensor of weighted features.  In eval mode,
            channel 0 of points with no contributing primitive is replaced by
            an approximate SDF derived from the nearest voxel of the nearest
            primitive.
        """
        bs = x.shape[0]
        mask = weights > 0
        ind_bs, ind_nprim = torch.where(mask)
        n_pairs = ind_nprim.shape[0]

        # Local coordinates are only needed for the active (point, prim)
        # pairs, so compute them directly instead of for the full bs x N grid.
        local_pts = (x[ind_bs] - self.pos[ind_nprim]) / self.scale[ind_nprim]
        local_pts = local_pts.reshape(n_pairs, 1, 1, 1, 3)
        feat4sample = self.feat[ind_nprim, :].reshape(
            n_pairs, self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape
        )

        # For 5-D input, mode='bilinear' performs trilinear interpolation.
        sampled_feat = F.grid_sample(
            feat4sample, local_pts, mode='bilinear', padding_mode='zeros', align_corners=True
        ).reshape(n_pairs, self.dim_feat)
        weighted_sampled_feat = sampled_feat * weights[ind_bs, ind_nprim][:, None]

        # Sum the contributions of all primitives covering the same point.
        weighted_feat = x.new_zeros(bs, self.dim_feat)
        weighted_feat.index_add_(0, ind_bs, weighted_sampled_feat)

        # at inference time, fill in approximated SDF value for region not covered by prims
        if not self.training:
            # get mask for points not covered by prims
            bs_mask = weights.sum(1) <= 0
            if bs_mask.any():
                uncovered = x[bs_mask]

                # get nearest prim index
                dist = torch.norm(uncovered[:, None, :] - self.pos[None, ...], p=2, dim=-1)
                _, min_dist_ind = dist.min(1)
                nearest_prim_pos = self.pos[min_dist_ind, :]
                nearest_prim_scale = self.scale[min_dist_ind, :]

                # in each nearest prim, get nearest voxel points
                candidate_nearest_pts = (
                    nearest_prim_pos[:, None, :]
                    + nearest_prim_scale[..., None] * self.local_grid.to(x)[None, :]
                )
                pts_dist = torch.norm(uncovered[:, None, :] - candidate_nearest_pts, p=2, dim=-1)
                min_dist, min_dist_pts_ind = pts_dist.min(1)

                # get the SDF value as a nearest valid SDF value
                min_pts_sdf = self.feat_geo[min_dist_ind, min_dist_pts_ind]
                # approximate SDF value with the same sign distance + L2 distance
                approx_sdf = min_pts_sdf + min_dist * torch.sign(min_pts_sdf)
                weighted_feat[bs_mask, 0:1] = approx_sdf[:, None]
        return weighted_feat

    def prim_weight(self, x):
        """Compute normalized per-primitive blend weights for query points.

        The raw weight of a primitive is ``relu(1 - d)`` where ``d`` is the
        L-infinity norm of the point in the primitive's local frame
        ``(x - pos) / scale``; it is zero outside the primitive's cube.  Raw
        weights are divided by their per-point sum plus ``1e-6``.

        Note: ``scale`` is used as a divisor, so it must be non-zero.

        Args:
            x: Query points, ``[bs, 3]``.

        Returns:
            Normalized weights, ``[bs, N]``.
        """
        local_pts = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...]
        raw_weights = F.relu(1 - torch.norm(local_pts, p=float('inf'), dim=-1))
        # raw_weights - [bs, N]
        return raw_weights / (torch.sum(raw_weights, dim=-1, keepdim=True) + 1e-6)

    def sdf2alpha(self, sdf):
        """Map SDF values to ``exp(-(sdf / sdf2alpha_var) ** 2)``."""
        return torch.exp(-(sdf / self.sdf2alpha_var) ** 2)

    @property
    def pos(self):
        """Primitive positions: columns 1..3 of ``srt_param``, ``[N, 3]``."""
        return self.srt_param[:, 1:4]

    @property
    def scale(self):
        """Primitive scales: column 0 of ``srt_param``, ``[N, 1]``."""
        return self.srt_param[:, 0:1]

    @property
    def feat(self):
        """Full flattened feature parameter, ``[N, dim_feat * prim_shape^3]``."""
        return self.feat_param

    @property
    def feat_geo(self):
        """Geometry (SDF) columns of ``feat_param``."""
        return self.feat_param[:, self.geo_start_index:self.geo_end_index]

    @property
    def feat_tex(self):
        """Texture (RGB) columns of ``feat_param``."""
        return self.feat_param[:, self.tex_start_index:self.tex_end_index]

    @property
    def feat_mat(self):
        """Material (roughness, metallic) columns of ``feat_param``."""
        return self.feat_param[:, self.mat_start_index:self.mat_end_index]