# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Dict, Tuple

import numpy as np

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import random

from dva.mvp.extensions.mvpraymarch.mvpraymarch import mvpraymarch
from dva.mvp.extensions.utils.utils import compute_raydirs

import logging

logger = logging.getLogger(__name__)


def convert_camera_parameters(Rt, K):
    """Split batched extrinsics / intrinsics into the pieces used for ray setup.

    Args:
        Rt: batched extrinsics; only ``Rt[:, :3, :3]`` (rotation) and
            ``Rt[:, :3, 3]`` (translation) are read.
        K: batched intrinsics; only ``K[:, :2, :2]`` and ``K[:, :2, 2]``
            are read.

    Returns:
        A dict with keys:
            ``campos``: ``-R^T @ t`` per batch element,
            ``camrot``: the rotation block ``R``,
            ``focal``: the upper-left 2x2 block of ``K``,
            ``princpt``: the first two entries of the last column of ``K``.
    """
    R = Rt[:, :3, :3]
    # Camera position expressed through the inverse (transposed) rotation.
    t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
    return dict(
        campos=t,
        camrot=R,
        focal=K[:, :2, :2],
        princpt=K[:, :2, 2],
    )


def subsample_pixel_coords(
    pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4
) -> th.Tensor:
    """Pick a randomly-offset strided subset of the pixel grid per batch item.

    For every batch element an offset ``(x0, y0)`` is drawn uniformly from
    ``[0, ray_subsample_factor)`` and every ``ray_subsample_factor``-th pixel
    starting at that offset is kept.

    Args:
        pixel_coords: pixel grid indexed as ``[row, col, channel]``.
        batch_size: number of independent subsamplings to draw.
        ray_subsample_factor: stride between kept pixels along both axes.

    Returns:
        Tensor of shape ``[batch_size, H // factor, W // factor, C]``.
    """
    H, W = pixel_coords.shape[:2]
    SW = W // ray_subsample_factor
    SH = H // ray_subsample_factor
    step = ray_subsample_factor

    all_coords = []
    for _ in range(batch_size):
        # NOTE: `high` of th.randint is exclusive, so it has to be the factor
        # itself for the last offset (factor - 1) to be reachable. The largest
        # index touched is (factor - 1) + factor * (S - 1) <= dim - 1, so the
        # slice always stays inside the grid.
        x0 = int(th.randint(0, ray_subsample_factor, size=()))
        y0 = int(th.randint(0, ray_subsample_factor, size=()))
        x1 = x0 + step * SW
        y1 = y0 + step * SH
        all_coords.append(pixel_coords[y0:y1:step, x0:x1:step, :])
    return th.stack(all_coords, dim=0)


def resize_pixel_coords(
    pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4
) -> th.Tensor:
    """Deterministically subsample the pixel grid with a fixed centred offset.

    Same strided selection as :func:`subsample_pixel_coords`, but the offset
    is always ``ray_subsample_factor // 2`` along both axes, so every batch
    element receives the same coordinates.

    Args:
        pixel_coords: pixel grid indexed as ``[row, col, channel]``.
        batch_size: number of copies stacked along the new leading dim.
        ray_subsample_factor: stride between kept pixels along both axes.

    Returns:
        Tensor with ``batch_size`` stacked copies of the strided grid.
    """
    H, W = pixel_coords.shape[:2]
    SW = W // ray_subsample_factor
    SH = H // ray_subsample_factor
    step = ray_subsample_factor

    x0 = y0 = ray_subsample_factor // 2
    x1 = x0 + step * SW
    y1 = y0 + step * SH
    # The slice does not depend on the batch index: compute it once.
    coords = pixel_coords[y0:y1:step, x0:x1:step, :]
    return th.stack([coords] * batch_size, dim=0)


class RayMarcher(nn.Module):
    """Module wrapping ``compute_raydirs`` + ``mvpraymarch`` for primitives.

    Holds the rendering resolution, the volume radius used to normalize
    positions and step size, and the marching hyper-parameters.
    """

    def __init__(
        self,
        image_height,
        image_width,
        volradius,
        fadescale=8.0,
        fadeexp=8.0,
        dt=1.0,
        ray_subsample_factor=1,
        accum=2,
        termthresh=0.99,
        blocksize=None,
        with_t_img=True,
        chlast=False,
        assets=None,
    ):
        """Store the configuration and build the base pixel grid.

        Args:
            image_height: height of the full pixel grid.
            image_width: width of the full pixel grid.
            volradius: divisor applied to primitive positions and to ``dt``.
            fadescale: forwarded to ``mvpraymarch``.
            fadeexp: forwarded to ``mvpraymarch``.
            dt: step size before division by ``volradius``.
            ray_subsample_factor: default subsampling stride for ``forward``.
            accum: stored on the module; not read in this file.
            termthresh: forwarded to ``mvpraymarch``.
            blocksize: stored on the module; defaults to ``(8, 16)``.
            with_t_img: stored on the module; not read in this file.
            chlast: stored on the module; not read in this file.
            assets: accepted but not used in this file.
        """
        super().__init__()
        # TODO: add config?
        self.image_height = image_height
        self.image_width = image_width
        self.volradius = volradius
        self.dt = dt

        self.fadescale = fadescale
        self.fadeexp = fadeexp

        # NOTE: this seems to not work for other configs?
        if blocksize is None:
            blocksize = (8, 16)

        self.blocksize = blocksize
        self.with_t_img = with_t_img
        self.chlast = chlast

        self.accum = accum
        self.termthresh = termthresh

        base_pixel_coords = self._make_pixel_coords(
            self.image_height, self.image_width
        )
        # Non-persistent: the grid is derived data and stays out of state_dict.
        self.register_buffer("base_pixel_coords", base_pixel_coords, persistent=False)
        self.fixed_bvh_cache = {-1: (th.empty(0), th.empty(0), th.empty(0))}
        self.ray_subsample_factor = ray_subsample_factor

    @staticmethod
    def _make_pixel_coords(height, width, device=None):
        """Build an ``[height, width, 2]`` grid holding ``(x, y)`` per pixel."""
        # meshgrid yields (row index, column index); reversing the pair puts
        # the column (x) first in the last dimension.
        return th.stack(
            th.meshgrid(
                th.arange(height, dtype=th.float32, device=device),
                th.arange(width, dtype=th.float32, device=device),
            )[::-1],
            dim=-1,
        )

    def _set_pix_coords(self):
        """Rebuild the pixel grid for the current size on the buffer's device."""
        dev = self.base_pixel_coords.device
        self.base_pixel_coords = self._make_pixel_coords(
            self.image_height, self.image_width, device=dev
        )

    def resize(self, h: int, w: int):
        """Change the rendering resolution and regenerate the pixel grid."""
        self.image_height = h
        self.image_width = w

        self._set_pix_coords()

    def forward(
        self,
        prim_rgba: th.Tensor,
        prim_pos: th.Tensor,
        prim_rot: th.Tensor,
        prim_scale: th.Tensor,
        K: th.Tensor,
        RT: th.Tensor,
        ray_subsample_factor: Optional[int] = None,
    ):
        """
        Args:
            prim_rgba: primitive payload [B, K, 4, S, S, S],
                K - # of primitives, S - primitive size
            prim_pos: locations [B, K, 3]
            prim_rot: rotations [B, K, 3, 3]
            prim_scale: scales [B, K, 3]
            K: intrinsics [B, 3, 3]
            RT: extrinsics [B, 3, 4]
            ray_subsample_factor: optional override of the module-level
                subsampling stride; ``None`` uses ``self.ray_subsample_factor``.
        Returns:
            a dict of tensors:
                ``rgba_image``: output of ``mvpraymarch`` permuted with
                    ``(0, 3, 1, 2)``,
                ``pixel_coords``: the pixel coordinates that were marched.
        """
        # TODO: maybe we can re-use mvpraymarcher?
        B = prim_rgba.shape[0]

        # TODO: this should return focal 2x2?
        camera = convert_camera_parameters(RT, K)
        camera = {k: v.contiguous() for k, v in camera.items()}

        dt = self.dt / self.volradius

        if ray_subsample_factor is None:
            ray_subsample_factor = self.ray_subsample_factor

        if ray_subsample_factor > 1 and self.training:
            # Training: random per-sample offsets.
            pixel_coords = subsample_pixel_coords(
                self.base_pixel_coords, int(B), ray_subsample_factor
            )
        elif ray_subsample_factor > 1:
            # Evaluation: fixed centred offset.
            pixel_coords = resize_pixel_coords(
                self.base_pixel_coords,
                int(B),
                ray_subsample_factor,
            )
        else:
            pixel_coords = (
                self.base_pixel_coords[np.newaxis].expand(B, -1, -1, -1).contiguous()
            )

        prim_pos = prim_pos / self.volradius

        # Only the diagonal of the 2x2 focal block is passed on.
        focal = th.diagonal(camera["focal"], dim1=1, dim2=2).contiguous()

        # TODO: port this?
        raypos, raydir, tminmax = compute_raydirs(
            viewpos=camera["campos"],
            viewrot=camera["camrot"],
            focal=focal,
            princpt=camera["princpt"],
            pixelcoords=pixel_coords,
            volradius=self.volradius,
        )

        # NOTE(review): `chlast` is hard-coded to True here (self.chlast is
        # not consulted); the template permute and the output permute below
        # both rely on that.
        rgba = mvpraymarch(
            raypos,
            raydir,
            stepsize=dt,
            tminmax=tminmax,
            algo=0,
            template=prim_rgba.permute(0, 1, 3, 4, 5, 2).contiguous(),
            warp=None,
            termthresh=self.termthresh,
            primtransf=(prim_pos, prim_rot, prim_scale),
            fadescale=self.fadescale,
            fadeexp=self.fadeexp,
            usebvh="fixedorder",
            chlast=True,
        )

        rgba = rgba.permute(0, 3, 1, 2)

        preds = {
            "rgba_image": rgba,
            "pixel_coords": pixel_coords,
        }

        return preds


def generate_colored_boxes(template, prim_rot, alpha=10000.0, seed=123456):
    """Fill a primitive payload with flat-shaded, randomly coloured boxes.

    Args:
        template: payload indexed as ``[B, K, C, S, S, S]``; it is cloned,
            not modified. Channels 0-2 are overwritten with colour and
            channel 3 with ``alpha``.
        prim_rot: accepted for interface compatibility; not used.
        alpha: constant written to channel 3 of every voxel.
        seed: seed of the colour generator; the same seed gives the same
            colours.

    Returns:
        A tensor shaped like ``template`` holding the coloured boxes.
    """
    B = template.shape[0]
    output = template.clone()
    device = template.device

    lightdir = -3 * th.ones([B, 3], dtype=th.float32, device=device)
    lightdir = lightdir / th.norm(lightdir, p=2, dim=1, keepdim=True)

    zz, yy, xx = th.meshgrid(
        th.linspace(-1.0, 1.0, template.size(-1), device=device),
        th.linspace(-1.0, 1.0, template.size(-1), device=device),
        th.linspace(-1.0, 1.0, template.size(-1), device=device),
    )
    # Per-voxel normal: the signed axis (or axes, on ties) along which the
    # voxel has the largest absolute coordinate.
    primnormalx = th.where(
        (th.abs(xx) >= th.abs(yy)) & (th.abs(xx) >= th.abs(zz)),
        th.sign(xx) * th.ones_like(xx),
        th.zeros_like(xx),
    )
    primnormaly = th.where(
        (th.abs(yy) >= th.abs(xx)) & (th.abs(yy) >= th.abs(zz)),
        th.sign(yy) * th.ones_like(xx),
        th.zeros_like(xx),
    )
    primnormalz = th.where(
        (th.abs(zz) >= th.abs(xx)) & (th.abs(zz) >= th.abs(yy)),
        th.sign(zz) * th.ones_like(xx),
        th.zeros_like(xx),
    )
    primnormal = th.stack([primnormalx, -primnormaly, -primnormalz], dim=-1)
    # For an odd grid size the centre voxel sits at the origin and has a zero
    # normal; clamp the norm so it stays zero instead of becoming 0 / 0 = NaN.
    # Every non-zero normal has norm >= 1 and is unaffected.
    normal_len = th.sqrt(th.sum(primnormal**2, dim=-1, keepdim=True))
    primnormal = primnormal / normal_len.clamp(min=1e-8)

    output[:, :, 3, :, :, :] = alpha

    # get light direction in local coordinate system?
    # The shading term is the same for every primitive, so compute it once.
    mult = th.sum(
        lightdir[:, None, None, None, :] * primnormal[np.newaxis], dim=-1
    )[:, np.newaxis, :, :, :].clamp(min=0.2)

    # Local generator: same colour sequence as seeding the global NumPy RNG,
    # without clobbering the caller's global random state.
    rng = np.random.RandomState(seed)

    for i in range(template.size(1)):
        # generating a random color
        output[:, i, 0, :, :, :] = rng.rand() * 255.0
        output[:, i, 1, :, :, :] = rng.rand() * 255.0
        output[:, i, 2, :, :, :] = rng.rand() * 255.0

        output[:, i, :3, :, :, :] *= 1.4 * mult
    return output