# Inpaint/src/util/alignment.py
# Uploaded by ZehanWang using huggingface_hub (commit 864ec44, verified); 2.16 kB.
# Author: Bingxin Ke
# Last modified: 2024-01-11
import numpy as np
import torch
def _downsample_nearest(arr, scale_factor):
    """Nearest-neighbour resize a 2-D ``[H, W]`` array by ``scale_factor``.

    ``torch.nn.Upsample`` reads a 4-D input as ``[N, C, H, W]`` and resizes both
    spatial axes. A 3-D ``[1, H, W]`` input would be read as ``[N, C, W]`` and
    only the last axis would be resized, so two singleton axes are added here
    and stripped again afterwards.
    """
    downscaler = torch.nn.Upsample(scale_factor=float(scale_factor), mode="nearest")
    return downscaler(torch.as_tensor(arr)[None, None])[0, 0].numpy()


def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Fit ``gt ~= scale * pred + shift`` by least squares over valid pixels.

    Args:
        gt_arr: Ground-truth depth; squeezed (expected to become ``[H, W]``).
        pred_arr: Prediction with the same squeezed shape as ``gt_arr``.
        valid_mask_arr: Mask of pixels used in the fit; interpreted as boolean.
        return_scale_shift: If True, also return the fitted scale and shift.
        max_resolution: If given and larger than the last two dims of
            ``pred_arr`` allow, the fit is computed on a nearest-neighbour
            downsampled copy scaled by
            ``min(max_resolution / H, max_resolution / W)``.

    Returns:
        ``aligned_pred`` (``pred_arr * scale + shift`` reshaped to
        ``pred_arr.shape``), or ``(aligned_pred, scale, shift)`` when
        ``return_scale_shift`` is True. ``scale`` and ``shift`` are
        1-element arrays.

    Raises:
        AssertionError: If gt, pred and mask shapes disagree after squeezing
            / downsampling.

    NOTE(review): an all-False mask is not guarded against here.
    """
    ori_shape = pred_arr.shape  # input shape

    gt = gt_arr.squeeze()  # [H, W]
    pred = pred_arr.squeeze()
    # Boolean indexing is required below; an integer 0/1 mask would otherwise
    # be treated as fancy (row) indexing and select the wrong pixels.
    valid_mask = np.asarray(valid_mask_arr).squeeze().astype(bool)

    # Optionally fit on a downsampled copy to bound the least-squares cost.
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            gt = _downsample_nearest(gt, scale_factor)
            pred = _downsample_nearest(pred, scale_factor)
            valid_mask = _downsample_nearest(
                valid_mask.astype(np.float32), scale_factor
            ).astype(bool)

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # Solve [pred, 1] @ [scale, shift]^T ~= gt with the numpy solver.
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]  # shape (2, 1)
    scale, shift = X  # each a 1-element array

    # Apply the fit to the full-resolution prediction and restore its shape.
    aligned_pred = pred_arr * scale + shift
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    return aligned_pred
# ******************** disparity space ********************
def depth2disparity(depth, return_mask=False):
    """Convert depth to disparity as the element-wise reciprocal ``1 / depth``.

    Entries with ``depth <= 0`` are treated as invalid and set to 0.

    Args:
        depth: ``torch.Tensor`` or ``np.ndarray`` of depth values. Integer
            inputs are promoted to a floating dtype so reciprocals are not
            truncated.
        return_mask: If True, also return the boolean mask ``depth > 0``.

    Returns:
        ``disparity``, or ``(disparity, mask)`` when ``return_mask`` is True.
        ``disparity`` has the same container type and shape as ``depth``.

    Raises:
        TypeError: If ``depth`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(depth, torch.Tensor):
        # Keep floating dtypes; integer tensors would make the index
        # assignment below fail on a dtype mismatch.
        out_dtype = depth.dtype if depth.is_floating_point() else torch.get_default_dtype()
        disparity = torch.zeros_like(depth, dtype=out_dtype)
    elif isinstance(depth, np.ndarray):
        # Integer arrays would silently truncate 1/depth to 0.
        out_dtype = depth.dtype if np.issubdtype(depth.dtype, np.floating) else np.float64
        disparity = np.zeros_like(depth, dtype=out_dtype)
    else:
        raise TypeError(
            f"depth must be a torch.Tensor or np.ndarray, got {type(depth).__name__}"
        )
    # Only strictly positive entries are inverted: zero would divide by zero,
    # and negative values are treated as invalid as well.
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]
    if return_mask:
        return disparity, positive_mask
    return disparity
def disparity2depth(disparity, **kwargs):
    """Convert disparity back to depth.

    The element-wise reciprocal is its own inverse, so this reuses
    ``depth2disparity``: non-positive entries map to 0, and keyword arguments
    (e.g. ``return_mask``) are forwarded unchanged.
    """
    converted = depth2disparity(disparity, **kwargs)
    return converted