# 3DTopia-XL / dva/visualize.py
# FrozenBurning -- "single view to 3D init release" (81ecb2b)
import cv2
import os
import numpy as np
import torch
import imageio
from torchvision.utils import make_grid, save_image
from .ray_marcher import RayMarcher, generate_colored_boxes
def get_pose_on_orbit(radius, height, angles, world_up=torch.Tensor([0, 1, 0])):
    """Build one 3x4 pose matrix per angle for points on a horizontal orbit.

    For every angle the point ``(radius * cos(a), height, radius * sin(a))`` is
    placed on the orbit.  The rows of the rotation are the right / up / forward
    axes of an orthonormal frame whose forward axis is the normalised point
    position.  The last column is the constant translation ``(0, 0, radius)``.

    Args:
        radius: orbit radius (scalar).
        height: y coordinate shared by all orbit points (scalar).
        angles: 1-D tensor with N angles in radians.
        world_up: 3-vector used as the up reference when building the frame.
            It must not be parallel to any orbit point position, otherwise
            the right axis has zero length and the normalisation divides by 0.

    Returns:
        Tensor of shape ``[N, 3, 4]`` holding ``[rotation | translation]``.
    """
    # NOTE(review): the translation uses ``radius`` rather than the distance of
    # the orbit point from the origin, so the two only agree for height == 0
    # (the only value used in this file) -- confirm before using other heights.
    num_points = angles.shape[0]
    x = radius * torch.cos(angles)
    h = torch.ones((num_points,)) * height
    z = radius * torch.sin(angles)
    position = torch.stack([x, h, z], dim=-1)
    forward = position / torch.norm(position, p=2, dim=-1, keepdim=True)
    # ``dim=-1`` is passed explicitly: without it torch.cross uses the first
    # axis of size 3 of its first argument, which is the batch axis when
    # exactly three angles are requested and silently yields a wrong frame.
    right = -torch.cross(world_up[None, ...], forward, dim=-1)
    right = right / torch.norm(right, dim=-1, keepdim=True)
    up = torch.cross(forward, right, dim=-1)
    up = up / torch.norm(up, p=2, dim=-1, keepdim=True)
    rotation = torch.stack([right, up, forward], dim=1)
    translation = torch.Tensor([0, 0, radius])[None, :, None].repeat(num_points, 1, 1)
    return torch.concat([rotation, translation], dim=2)
def render_mvp_boxes(rm, batch, preds):
    """Ray-march the colored-box visualization of the primitives.

    The primitive payload from ``preds`` is replaced by the output of
    ``generate_colored_boxes``; primitive positions, scales and rotations as
    well as the cameras stored in ``batch`` ("Rt", "K") are reused unchanged.

    Returns:
        The first three channels of the ray marcher's ``rgba_image`` with the
        channel axis moved to the end.
    """
    prim_rotation = preds["prim_rot"]
    with torch.no_grad():
        colored_boxes = generate_colored_boxes(preds["prim_rgba"], prim_rotation)
        render_kwargs = dict(
            prim_rgba=colored_boxes,
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=prim_rotation,
            RT=batch["Rt"],
            K=batch["K"],
        )
        rendered = rm(**render_kwargs)
    box_rgb = rendered["rgba_image"][:, :3]
    return box_rgb.permute(0, 2, 3, 1)
def save_image_summary(path, batch, preds):
    """Save renders and their box visualizations stacked in one image grid.

    ``preds["rgb"]`` and ``preds["rgb_boxes"]`` are channel-last image batches
    (they are permuted to channel-first here) on a 0-255 scale (they are
    divided by 255 and clipped to [0, 1] before saving).  Each render is
    stacked on top of its box visualization and all pairs go into one grid.

    If ``batch`` has both "folder" and "key", the two strings are drawn onto
    each box image and ``<path without extension>.txt`` is written with one
    ``folder/key`` line per image.

    Args:
        path: output image path.
        batch: dict; only the optional "folder" / "key" sequences are read.
        preds: dict with "rgb" and "rgb_boxes" tensors.  Not modified.
    """
    rgb = preds["rgb"].detach().permute(0, 3, 1, 2)
    # detach()/permute() only create views; clone so the text overlay below is
    # not written back into the caller's ``preds["rgb_boxes"]``.
    rgb_boxes = preds["rgb_boxes"].detach().permute(0, 3, 1, 2).clone()
    num_images = rgb_boxes.shape[0]
    if "folder" in batch and "key" in batch:
        folders = batch["folder"]
        keys = batch["key"]
        num_labels = len(folders)
        obj_list = []
        for img_idx in range(num_images):
            # Clamp before the uint8 cast so out-of-range floats saturate
            # instead of wrapping around (the final grid is clipped anyway).
            frame = rgb_boxes[img_idx].permute(1, 2, 0).clamp(0, 255).to(torch.uint8).cpu().numpy()
            frame = np.ascontiguousarray(frame)
            # The image batch may hold several views stacked view-major along
            # the batch axis (as visualize_multiview_primvolume produces), so
            # wrap around the label list instead of indexing past its end.
            label_idx = img_idx % num_labels
            folder = folders[label_idx]
            key = keys[label_idx]
            obj_list.append("{}/{}\n".format(folder, key))
            cv2.putText(frame, "{}".format(folder), (200, 200), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            cv2.putText(frame, "{}".format(key), (200, 400), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            rgb_boxes[img_idx] = torch.as_tensor(frame).permute(2, 0, 1).float()
        with open(os.path.splitext(path)[0] + ".txt", "w") as f:
            f.writelines(obj_list)
    img = make_grid(torch.cat([rgb, rgb_boxes], dim=2) / 255.0).clip(0.0, 1.0)
    save_image(img, path)
@torch.no_grad()
def visualize_primsdf_box(image_save_path, model, rm: RayMarcher, device):
    """Render the primitives stored on ``model`` from a fixed camera and save a summary.

    The primitive payload is built directly from ``model.feat_geo`` (through
    ``model.sdf2alpha``) and ``model.feat_tex``.  The render and the colored
    box visualization are written to ``image_save_path``.

    Ray marcher inputs (from the original notes):
        prim_rgba: primitive payload [B, K, 4, S, S, S],
            K - # of primitives, S - primitive size
        prim_pos: locations [B, K, 3]
        prim_rot: rotations [B, K, 3, 3]
        prim_scale: scales [B, K, 3]
        K: intrinsics [B, 3, 3]
        RT: extrinsics [B, 3, 4]

    Args:
        image_save_path: output path handed to ``save_image_summary``.
        model: object exposing ``feat_geo``, ``feat_tex``, ``pos``, ``scale``,
            ``num_prims``, ``prim_shape`` and ``sdf2alpha``.
        rm: ray marcher; ``volradius``, ``image_height`` and ``image_width``
            are read from it.
        device: device the camera matrices are created on.
    """
    num_prims = model.num_prims
    grid = (model.prim_shape,) * 3
    preds = {}
    batch = {}
    prim_alpha = model.sdf2alpha(model.feat_geo).reshape(1, num_prims, 1, *grid) * 255
    prim_rgb = model.feat_tex.reshape(1, num_prims, 3, *grid) * 255
    preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2)
    preds['prim_pos'] = model.pos.reshape(1, num_prims, 3) * rm.volradius
    # Identity rotation for every primitive.
    preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(1, num_prims, 1, 1)
    preds['prim_scale'] = (1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3))
    # Fixed extrinsics: y and z axes flipped, translated 5 * volradius along z.
    batch['Rt'] = torch.Tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 5 * rm.volradius],
    ]).to(device)[None, ...]
    # Intrinsics referenced to a 1024-pixel image side (principal point 512).
    batch['K'] = torch.Tensor([
        [2084.9526697685183, 0.0, 512.0],
        [0.0, 2084.9526697685183, 512.0],
        [0.0, 0.0, 1.0],
    ]).to(device)[None, ...]
    # Row 0 holds (fx, cx) and therefore scales with the image width; row 1
    # holds (fy, cy) and scales with the height.  The ratios used to be
    # swapped, which only went unnoticed for square renders.
    batch['K'][:, 0:1, :] *= rm.image_width / 1024.
    batch['K'][:, 1:2, :] *= rm.image_height / 1024.
    # raymarcher is in mm
    rm_preds = rm(
        prim_rgba=preds["prim_rgba"],
        prim_pos=preds["prim_pos"],
        prim_scale=preds["prim_scale"],
        prim_rot=preds["prim_rot"],
        RT=batch["Rt"],
        K=batch["K"],
    )
    rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
    preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
    preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
@torch.no_grad()
def render_primsdf(image_save_path, model, rm, device):
    """Query ``model`` at its sampled points, render the result and save a summary.

    Unlike reading stored features, the SDF and texture values are obtained by
    calling ``model`` on ``model.sdf_sampled_point[:, i, :]`` for each of the
    ``prim_shape ** 3`` sample indices.  The predictions' "sdf" entry goes
    through ``model.sdf2alpha`` to become the alpha channel, "tex" becomes the
    color channels, and the primitives are ray-marched from a fixed camera.

    Args:
        image_save_path: output path handed to ``save_image_summary``.
        model: callable returning a dict with "sdf" and "tex"; also exposes
            ``pos``, ``scale``, ``num_prims``, ``prim_shape``,
            ``sdf_sampled_point`` and ``sdf2alpha``.
        rm: ray marcher; ``volradius``, ``image_height`` and ``image_width``
            are read from it.
        device: device for the camera matrices and the queried sample points.
    """
    num_prims = model.num_prims
    grid = (model.prim_shape,) * 3
    preds = {}
    batch = {}
    preds['prim_pos'] = model.pos.reshape(1, num_prims, 3) * rm.volradius
    # Identity rotation for every primitive.
    preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(1, num_prims, 1, 1)
    preds['prim_scale'] = (1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3))
    # Fixed extrinsics: y and z axes flipped, translated 5 * volradius along z.
    batch['Rt'] = torch.Tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 5 * rm.volradius],
    ]).to(device)[None, ...]
    # Intrinsics referenced to a 1024-pixel image side (principal point 512).
    batch['K'] = torch.Tensor([
        [2084.9526697685183, 0.0, 512.0],
        [0.0, 2084.9526697685183, 512.0],
        [0.0, 0.0, 1.0],
    ]).to(device)[None, ...]
    # Row 0 holds (fx, cx) and therefore scales with the image width; row 1
    # holds (fy, cy) and scales with the height.  The ratios used to be
    # swapped, which only went unnoticed for square renders.
    batch['K'][:, 0:1, :] *= rm.image_width / 1024.
    batch['K'][:, 1:2, :] *= rm.image_height / 1024.

    # test rendering: evaluate the model once per sample index
    all_sampled_sdf = []
    all_sampled_tex = []
    for i in range(model.prim_shape ** 3):
        model_prediction = model(model.sdf_sampled_point[:, i, :].to(device))
        all_sampled_sdf.append(model_prediction['sdf'])
        all_sampled_tex.append(model_prediction['tex'])
    sampled_sdf = torch.stack(all_sampled_sdf, dim=1)
    sampled_tex = torch.stack(all_sampled_tex, dim=1).permute(0, 2, 1)
    prim_rgb = sampled_tex.reshape(1, num_prims, 3, *grid) * 255
    prim_alpha = model.sdf2alpha(sampled_sdf).reshape(1, num_prims, 1, *grid) * 255
    preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2)

    rm_preds = rm(
        prim_rgba=preds["prim_rgba"],
        prim_pos=preds["prim_pos"],
        prim_scale=preds["prim_scale"],
        prim_rot=preds["prim_rot"],
        RT=batch["Rt"],
        K=batch["K"],
    )
    rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
    preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
    preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
@torch.no_grad()
def visualize_primvolume(image_save_path, batch, prim_volume, rm: RayMarcher, device):
    """Render a batch of packed primitive volumes from a fixed camera and save a summary.

    ``prim_volume`` is ``[B, nprims, 4+6*8^3]`` per the original note.  The
    last axis is sliced as: scale (1), position (3), geometry / SDF (S^3),
    texture (3 * S^3); the remaining entries are not used here.  ``S`` is
    recovered from the length of the last axis.

    Note: "Rt" and "K" are written into the caller's ``batch`` dict.

    Args:
        image_save_path: output path handed to ``save_image_summary``.
        batch: dict that receives the cameras; passed on to
            ``save_image_summary``.
        prim_volume: packed primitives, see above.
        rm: ray marcher; ``volradius``, ``image_height`` and ``image_width``
            are read from it.
        device: device the camera matrices are created on.
    """
    def sdf2alpha(sdf):
        # Gaussian falloff around the zero level set.
        return torch.exp(-(sdf / 0.005) ** 2)

    preds = {}
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3)))
    num_prims = prim_volume.shape[1]
    bs = prim_volume.shape[0]
    grid = (prim_shape,) * 3
    geo_start_index = 4
    geo_end_index = geo_start_index + prim_shape ** 3  # non-inclusive
    tex_start_index = geo_end_index
    tex_end_index = tex_start_index + prim_shape ** 3 * 3  # non-inclusive
    feat_geo = prim_volume[:, :, geo_start_index: geo_end_index]
    feat_tex = prim_volume[:, :, tex_start_index: tex_end_index]
    prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, *grid) * 255
    prim_rgb = feat_tex.reshape(bs, num_prims, 3, *grid) * 255
    preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2)
    pos = prim_volume[:, :, 1:4]
    scale = prim_volume[:, :, 0:1]
    preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius
    # Identity rotation for every primitive.
    preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1)
    preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3))
    # Fixed extrinsics: y and z axes flipped, translated 5 * volradius along z.
    batch['Rt'] = torch.Tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 5 * rm.volradius],
    ]).to(device)[None, ...].repeat(bs, 1, 1)
    # Intrinsics referenced to a 1024-pixel image side (principal point 512).
    batch['K'] = torch.Tensor([
        [2084.9526697685183, 0.0, 512.0],
        [0.0, 2084.9526697685183, 512.0],
        [0.0, 0.0, 1.0],
    ]).to(device)[None, ...].repeat(bs, 1, 1)
    # Row 0 holds (fx, cx) and therefore scales with the image width; row 1
    # holds (fy, cy) and scales with the height.  The ratios used to be
    # swapped, which only went unnoticed for square renders.
    batch['K'][:, 0:1, :] *= rm.image_width / 1024.
    batch['K'][:, 1:2, :] *= rm.image_height / 1024.
    # raymarcher is in mm
    rm_preds = rm(
        prim_rgba=preds["prim_rgba"],
        prim_pos=preds["prim_pos"],
        prim_scale=preds["prim_scale"],
        prim_rot=preds["prim_rot"],
        RT=batch["Rt"],
        K=batch["K"],
    )
    rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
    preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
    preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
@torch.no_grad()
def visualize_multiview_primvolume(image_save_path, batch, prim_volume, view_counts, rm: RayMarcher, device):
    """Render packed primitive volumes from several orbit views and save one summary grid.

    ``view_counts`` angles are taken evenly from [0.5*pi, 2.5*pi) and each is
    turned into a pose with ``get_pose_on_orbit`` (radius ``5 * volradius``,
    height 0).  The renders of all views are concatenated along the batch
    axis, view after view, before being saved.

    ``prim_volume`` is ``[B, nprims, 4+6*8^3]`` per the original note.  The
    last axis is sliced as: scale (1), position (3), geometry / SDF (S^3),
    texture (3 * S^3); the remaining entries are not used here.

    Note: "K" and "Rt" are written into the caller's ``batch`` dict; "Rt"
    ends up holding the poses of the last view.

    Args:
        image_save_path: output path handed to ``save_image_summary``.
        batch: dict that receives the cameras; passed on to
            ``save_image_summary``.
        prim_volume: packed primitives, see above.
        view_counts: number of views to render.
        rm: ray marcher; ``volradius``, ``image_height`` and ``image_width``
            are read from it.
        device: device the intrinsics are created on (the poses follow
            ``prim_volume``).
    """
    # Drop the last sample: 2.5*pi is the same direction as 0.5*pi.
    view_angles = (torch.linspace(0.5, 2.5, view_counts + 1) * torch.pi)[:-1]

    def sdf2alpha(sdf):
        # Gaussian falloff around the zero level set.
        return torch.exp(-(sdf / 0.005) ** 2)

    preds = {}
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3)))
    num_prims = prim_volume.shape[1]
    bs = prim_volume.shape[0]
    grid = (prim_shape,) * 3
    geo_start_index = 4
    geo_end_index = geo_start_index + prim_shape ** 3  # non-inclusive
    tex_start_index = geo_end_index
    tex_end_index = tex_start_index + prim_shape ** 3 * 3  # non-inclusive
    feat_geo = prim_volume[:, :, geo_start_index: geo_end_index]
    feat_tex = prim_volume[:, :, tex_start_index: tex_end_index]
    prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, *grid) * 255
    prim_rgb = feat_tex.reshape(bs, num_prims, 3, *grid) * 255
    preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2)
    pos = prim_volume[:, :, 1:4]
    scale = prim_volume[:, :, 0:1]
    preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius
    # Identity rotation for every primitive.
    preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1)
    preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3))
    # Intrinsics referenced to a 1024-pixel image side (principal point 512).
    batch['K'] = torch.Tensor([
        [2084.9526697685183, 0.0, 512.0],
        [0.0, 2084.9526697685183, 512.0],
        [0.0, 0.0, 1.0],
    ]).to(device)[None, ...].repeat(bs, 1, 1)
    # Row 0 holds (fx, cx) and therefore scales with the image width; row 1
    # holds (fy, cy) and scales with the height.  The ratios used to be
    # swapped, which only went unnoticed for square renders.
    batch['K'][:, 0:1, :] *= rm.image_width / 1024.
    batch['K'][:, 1:2, :] *= rm.image_height / 1024.

    final_preds = {'rgb': [], 'rgb_boxes': []}
    for view_ang in view_angles:
        # Same angle for every item of the batch.
        bs_view_ang = view_ang.repeat(bs,)
        batch['Rt'] = get_pose_on_orbit(radius=5*rm.volradius, height=0, angles=bs_view_ang).to(prim_volume)
        # raymarcher is in mm
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
        final_preds['rgb'].append(preds['rgb'])
        final_preds['rgb_boxes'].append(preds['rgb_boxes'])
    final_preds['rgb'] = torch.concat(final_preds['rgb'], dim=0)
    final_preds['rgb_boxes'] = torch.concat(final_preds['rgb_boxes'], dim=0)
    save_image_summary(image_save_path, batch, final_preds)
@torch.no_grad()
def visualize_video_primvolume(video_save_folder, batch, prim_volume, view_counts, rm: RayMarcher, device):
    """Render an orbit around packed primitive volumes and write three videos.

    ``view_counts + 1`` angles are taken evenly from [1.5*pi, 3.5*pi], both
    ends included, and each is turned into a pose with ``get_pose_on_orbit``
    (radius ``5 * volradius``, height 0).  For every view the color payload,
    the colored boxes and the material payload are rendered.  The frames go
    to ``rgb.mp4``, ``prim.mp4`` and ``mat.mp4`` in ``video_save_folder`` at
    20 fps.

    ``prim_volume`` is ``[B, nprims, 4+6*8^3]`` per the original note.  The
    last axis is sliced as: scale (1), position (3), geometry / SDF (S^3),
    texture (3 * S^3), material (2 * S^3).  The two material channels are
    rendered as a 3-channel image with a zero first channel.

    Note: "K" and "Rt" are written into the caller's ``batch`` dict.  Frames
    of all batch items are concatenated along the batch axis view after view,
    so with B > 1 the items are interleaved in each video.

    Args:
        video_save_folder: folder the three mp4 files are written to.
        batch: dict that receives the cameras.
        prim_volume: packed primitives, see above.
        view_counts: number of orbit steps (``view_counts + 1`` frames per
            batch item).
        rm: ray marcher; ``volradius``, ``image_height`` and ``image_width``
            are read from it.
        device: device the intrinsics are created on (the poses follow
            ``prim_volume``).
    """
    view_angles = torch.linspace(1.5, 3.5, view_counts + 1) * torch.pi

    def sdf2alpha(sdf):
        # Gaussian falloff around the zero level set.
        return torch.exp(-(sdf / 0.005) ** 2)

    preds = {}
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3)))
    num_prims = prim_volume.shape[1]
    bs = prim_volume.shape[0]
    grid = (prim_shape,) * 3
    geo_start_index = 4
    geo_end_index = geo_start_index + prim_shape ** 3  # non-inclusive
    tex_start_index = geo_end_index
    tex_end_index = tex_start_index + prim_shape ** 3 * 3  # non-inclusive
    mat_start_index = tex_end_index
    mat_end_index = mat_start_index + prim_shape ** 3 * 2  # non-inclusive
    feat_geo = prim_volume[:, :, geo_start_index: geo_end_index]
    feat_tex = prim_volume[:, :, tex_start_index: tex_end_index]
    feat_mat = prim_volume[:, :, mat_start_index: mat_end_index]
    prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, *grid) * 255
    prim_rgb = feat_tex.reshape(bs, num_prims, 3, *grid) * 255
    prim_mat = feat_mat.reshape(bs, num_prims, 2, *grid) * 255
    # Pad the two material channels with a zero channel in front so they can
    # be ray-marched like an RGB payload.
    dummy_prim = torch.zeros_like(prim_mat[:, :, 0:1, ...])
    prim_mat = torch.concat([dummy_prim, prim_mat], dim=2)
    preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2)
    preds['prim_mata'] = torch.concat([prim_mat, prim_alpha], dim=2)
    pos = prim_volume[:, :, 1:4]
    scale = prim_volume[:, :, 0:1]
    preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius
    # Identity rotation for every primitive.
    preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1)
    preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3))
    # Intrinsics referenced to a 1024-pixel image side (principal point 512).
    batch['K'] = torch.Tensor([
        [2084.9526697685183, 0.0, 512.0],
        [0.0, 2084.9526697685183, 512.0],
        [0.0, 0.0, 1.0],
    ]).to(device)[None, ...].repeat(bs, 1, 1)
    # Row 0 holds (fx, cx) and therefore scales with the image width; row 1
    # holds (fy, cy) and scales with the height.  The ratios used to be
    # swapped, which only went unnoticed for square renders.
    batch['K'][:, 0:1, :] *= rm.image_width / 1024.
    batch['K'][:, 1:2, :] *= rm.image_height / 1024.

    final_preds = {'rgb': [], 'rgb_boxes': [], 'mat_rgb': []}
    for view_ang in view_angles:
        # Same angle for every item of the batch.
        bs_view_ang = view_ang.repeat(bs,)
        batch['Rt'] = get_pose_on_orbit(radius=5*rm.volradius, height=0, angles=bs_view_ang).to(prim_volume)
        # raymarcher is in mm
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
        # Second pass with the material payload in place of the colors.
        rm_preds = rm(
            prim_rgba=preds["prim_mata"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        mat_rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(mat_rgb=mat_rgba[..., :3].contiguous())
        final_preds['rgb'].append(preds['rgb'])
        final_preds['rgb_boxes'].append(preds['rgb_boxes'])
        final_preds['mat_rgb'].append(preds['mat_rgb'])
    assert len(final_preds['rgb']) == len(final_preds['rgb_boxes'])
    final_preds['rgb'] = torch.concat(final_preds['rgb'], dim=0)
    final_preds['rgb_boxes'] = torch.concat(final_preds['rgb_boxes'], dim=0)
    final_preds['mat_rgb'] = torch.concat(final_preds['mat_rgb'], dim=0)
    total_num_frames = final_preds['rgb'].shape[0]

    rgb_np = np.clip(final_preds['rgb'].detach().cpu().numpy(), 0, 255).astype(np.uint8)
    prim_np = np.clip(final_preds['rgb_boxes'].detach().cpu().numpy(), 0, 255).astype(np.uint8)
    mat_np = np.clip(final_preds['mat_rgb'].detach().cpu().numpy(), 0, 255).astype(np.uint8)

    rgb_video = os.path.join(video_save_folder, 'rgb.mp4')
    prim_video = os.path.join(video_save_folder, 'prim.mp4')
    mat_video = os.path.join(video_save_folder, 'mat.mp4')
    # Context managers close every writer even if appending a frame fails,
    # so no half-written file handle or encoder process is left behind.
    with imageio.get_writer(rgb_video, fps=20) as rgb_video_out, \
            imageio.get_writer(prim_video, fps=20) as prim_video_out, \
            imageio.get_writer(mat_video, fps=20) as mat_video_out:
        for fidx in range(total_num_frames):
            rgb_video_out.append_data(rgb_np[fidx])
            prim_video_out.append_data(prim_np[fidx])
            mat_video_out.append_data(mat_np[fidx])