# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-05-24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import logging
from typing import Dict, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import BaseOutput, make_image_grid
from PIL import Image, ImageDraw, ImageFont
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .duplicate_unet import DoubleUNet2DConditionModel
from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    # per-sample std of the text-conditioned and CFG-combined predictions (over all non-batch dims)
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]


class MarigoldInpaintPipeline(DiffusionPipeline):
    """
    Pipeline for text-guided RGB and depth (RGB-D) inpainting built on Marigold: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`].
    Check the superclass documentation for the generic methods the library implements for all the pipelines
    (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`DoubleUNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        scale_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set
            in the model config. When used together with the `shift_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        shift_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set
            in the model config. When used together with the `scale_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
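
    Example (a minimal usage sketch, not an officially documented API; the checkpoint path is a placeholder and the
    RGB/depth schedulers are attached manually because `__init__` leaves `rgb_scheduler`/`depth_scheduler` unset):

        >>> import torch
        >>> from PIL import Image
        >>> from torchvision.transforms.functional import pil_to_tensor
        >>> from diffusers import DDIMScheduler
        >>> pipe = MarigoldInpaintPipeline.from_pretrained(
        ...     "path/to/marigold-rgbd-inpaint-checkpoint", torch_dtype=torch.float16
        ... ).to("cuda")
        >>> pipe.rgb_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.depth_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> rgb = Image.open("scene.png").convert("RGB")
        >>> depth = Image.open("scene_depth.png")  # single-channel depth image
        >>> mask = pil_to_tensor(Image.open("scene_mask.png").convert("L")).float()  # white = region to inpaint
        >>> grid = pipe._rgbd_inpaint(rgb, depth, mask, prompt="a wooden chair", mode="joint_inpaint")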
""" rgb_latent_scale_factor = 0.18215 depth_latent_scale_factor = 0.18215 def __init__( self, unet: DoubleUNet2DConditionModel, vae: AutoencoderKL, scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, scale_invariant: Optional[bool] = True, shift_invariant: Optional[bool] = True, default_denoising_steps: Optional[int] = None, default_processing_resolution: Optional[int] = None, requires_safety_checker: bool = False, ): super().__init__() self.register_modules( unet=unet, vae=vae, scheduler=scheduler, text_encoder=text_encoder, tokenizer=tokenizer, ) self.register_to_config( scale_invariant=scale_invariant, shift_invariant=shift_invariant, default_denoising_steps=default_denoising_steps, default_processing_resolution=default_processing_resolution, ) self.scale_invariant = scale_invariant self.shift_invariant = shift_invariant self.default_denoising_steps = default_denoising_steps self.default_processing_resolution = default_processing_resolution self.rgb_scheduler = None self.depth_scheduler = None self.empty_text_embed = None self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.separate_list = [0,0] @torch.no_grad() def __call__( self, input_image: Union[Image.Image, torch.Tensor], denoising_steps: Optional[int] = None, ensemble_size: int = 5, processing_res: Optional[int] = None, match_input_res: bool = True, resample_method: str = "bilinear", batch_size: int = 0, generator: Union[torch.Generator, None] = None, color_map: str = "Spectral", show_progress_bar: bool = True, ensemble_kwargs: Dict = None, ) -> MarigoldDepthOutput: """ Function invoked when calling the pipeline. Args: input_image (`Image`): Input RGB (or gray-scale) image. denoising_steps (`int`, *optional*, defaults to `None`): Number of denoising diffusion steps during inference. The default value `None` results in automatic selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 for Marigold-LCM models. ensemble_size (`int`, *optional*, defaults to `10`): Number of predictions to be ensembled. processing_res (`int`, *optional*, defaults to `None`): Effective processing resolution. When set to `0`, processes at the original image resolution. This produces crisper predictions, but may also lead to the overall loss of global context. The default value `None` resolves to the optimal value from the model config. match_input_res (`bool`, *optional*, defaults to `True`): Resize depth prediction to match input resolution. Only valid if `processing_res` > 0. resample_method: (`str`, *optional*, defaults to `bilinear`): Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. batch_size (`int`, *optional*, defaults to `0`): Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. generator (`torch.Generator`, *optional*, defaults to `None`) Random generator for initial noise generation. show_progress_bar (`bool`, *optional*, defaults to `True`): Display a progress bar of diffusion denoising. 
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map
                generation):
                Colormap used to colorize the depth map.
            scale_invariant (`bool`, *optional*, defaults to `True`):
                Flag of scale-invariant prediction; if True, scale will be adjusted from the raw prediction.
            shift_invariant (`bool`, *optional*, defaults to `True`):
                Flag of shift-invariant prediction; if True, shift will be adjusted from the raw prediction; if False,
                near plane will be fixed at 0 m.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in
              [0, 1], None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        depth_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            depth_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            depth_pred_ls.append(depth_pred_raw.detach())
        depth_preds = torch.concat(depth_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            depth_pred, pred_uncert = ensemble_depth(
                depth_preds,
                scale_invariant=self.scale_invariant,
                shift_invariant=self.shift_invariant,
                max_res=50,
                **(ensemble_kwargs or {}),
            )
        else:
            depth_pred = depth_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            depth_pred = resize(
                depth_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        depth_pred = depth_pred.squeeze()
        depth_pred = depth_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range
        depth_pred = depth_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            depth_colored = colorize_depth_maps(
                depth_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            depth_colored = (depth_colored * 255).astype(np.uint8)
            depth_colored_hwc = chw2hwc(depth_colored)
            depth_colored_img = Image.fromarray(depth_colored_hwc)
        else:
            depth_colored_img = None

        return MarigoldDepthOutput(
            depth_np=depth_pred,
            depth_colored=depth_colored_img,
            uncertainty=pred_uncert,
        )

    def _replace_unet_conv_in(self):
        # Replace the first conv layer so the U-Net accepts 8 input channels
        # (image latent concatenated with depth latent).
        _weight = self.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.unet.conv_in.bias.clone()  # [320]
        # the additional 4 input channels are initialized with zero weights
        zero_weight = torch.zeros(_weight.shape).to(_weight.device)
        _weight = torch.cat([_weight, zero_weight], dim=1)
        # _weight = _weight.repeat((1, 2, 1, 1))  # Keep selected channel(s)
        # half the activation magnitude
        # _weight *= 0.5
        # new conv_in layer
        _n_convin_out_channel = self.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")
        # replace config
        self.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")
        return

    def _replace_unet_conv_out(self):
        # Replace the last conv layer so the U-Net produces 8 output channels
        # (noise predictions for both the image latent and the depth latent).
        _weight = self.unet.conv_out.weight.clone()  # [out_c, 320, 3, 3]
        _bias = self.unet.conv_out.bias.clone()  # [out_c]
        # duplicate the existing output channels for the additional modality
        _weight = _weight.repeat((2, 1, 1, 1))
        _bias = _bias.repeat(2)
        # new conv_out layer
        _n_convout_in_channel = self.unet.conv_out.in_channels
        _new_conv_out = Conv2d(
            _n_convout_in_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_out.weight = Parameter(_weight)
        _new_conv_out.bias = Parameter(_bias)
        self.unet.conv_out = _new_conv_out
        logging.info("Unet conv_out layer is replaced")
        # replace config
        self.unet.config["out_channels"] = 8
        logging.info("Unet config is updated")
        return

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable.

        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if n_step < 10:
                logging.warning(
                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            if not 1 <= n_step <= 4:
                logging.warning(
                    f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
                )
        elif isinstance(self.scheduler, PNDMScheduler):
            if n_step < 10:
                logging.warning(
                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
                )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    def encode_text(self, prompt):
        """
        Encode a text prompt into a text embedding.
        """
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
        return text_embed

    def numpy_to_pil(self, images: np.ndarray) -> list:
        """
        Convert a numpy image or a batch of images to a list of PIL images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def full_depth_rgb_inpaint(
        self,
        rgb_in,
        depth_in,
        image_mask,
        text_embed,
        timesteps,
        generator,
        guidance_scale,
    ):
        # Full depth is given: inpaint the masked RGB region conditioned on the complete depth map.
        depth_latent = self.encode_depth(depth_in)
        depth_mask = torch.zeros_like(image_mask)
        depth_mask_latent = self.encode_depth(depth_in)

        rgb_latent = torch.randn(
            depth_latent.shape,
            device=self.device,
            dtype=self.unet.dtype,
            generator=generator,
        ) * self.rgb_scheduler.init_noise_sigma
        rgb_mask = image_mask
        rgb_mask_latent = self.encode_rgb(rgb_in * (image_mask.squeeze() < 0.5), generator=generator)

        rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
        depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])

        for i, t in enumerate(timesteps):
            # concatenate RGB- and depth-branch inputs along the channel dimension
            cat_latent = torch.cat(
                [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
                 depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
                dim=1,
            ).float()
            latent_model_input = torch.cat([cat_latent] * 2)

            # predict the noise residual
            with torch.no_grad():
                partial_noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                    depth2rgb_scale=0.2,
                )[0]

                noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                    # separate_list=self.separate_list
                )[0]

            # perform guidance: combine the empty-text prediction with reduced depth-to-RGB conditioning,
            # the empty-text full prediction, and the text-conditioned prediction with fixed weights (2 and 3);
            # the `guidance_scale` argument is not used by this mode
            rgb_pred_wo_depth_text = partial_noise_pred[0, :4, :, :]
            rgb_pred_wo_text = noise_pred[0, :4, :, :]
            rgb_pred = noise_pred[1, :4, :, :]
            noise_pred = (
                rgb_pred_wo_depth_text
                + 2 * (rgb_pred_wo_text - rgb_pred_wo_depth_text)
                + 3 * (rgb_pred - rgb_pred_wo_text)
            )

            # compute the previous noisy sample x_t -> x_t-1
            rgb_latent = self.rgb_scheduler.step(noise_pred, t, rgb_latent).prev_sample

        return rgb_latent, depth_latent

    def full_rgb_depth_inpaint(
        self,
        rgb_in,
        depth_in,
        image_mask,
        text_embed,
        timesteps,
        generator,
        guidance_scale,
    ):
        # Full RGB is given: inpaint the masked depth region conditioned on the complete image.
        rgb_latent = self.encode_rgb(rgb_in)
        rgb_mask = torch.zeros_like(image_mask)
        rgb_mask_latent = self.encode_rgb(rgb_in)

        depth_latent = torch.randn(
            rgb_latent.shape,
            device=self.device,
            dtype=self.unet.dtype,
            generator=generator,
        ) * self.depth_scheduler.init_noise_sigma
        depth_mask = image_mask
        depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))
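
        # Note: `image_mask` is binarized by `mask_processor` (do_binarize=True). Assuming the usual convention
        # that nonzero mask values mark the region to inpaint, `image_mask.squeeze() < 0.5` selects the known
        # pixels, so the conditioning latent is encoded from the unmasked content only.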
        rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
        depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])

        for i, t in enumerate(timesteps):
            # concatenate RGB- and depth-branch inputs along the channel dimension
            cat_latent = torch.cat(
                [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
                 depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
                dim=1,
            ).float()
            latent_model_input = torch.cat([cat_latent] * 2)

            # predict the noise residual
            with torch.no_grad():
                partial_noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                    rgb2depth_scale=0.2,
                )[0]

                noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                    # separate_list=self.separate_list
                )[0]

            # compute the previous noisy sample x_t -> x_t-1
            # guidance: amplify the RGB-to-depth conditioning with a fixed weight (4);
            # the `guidance_scale` argument is not used by this mode
            depth_pre_wo_rgb = partial_noise_pred[1, 4:, :, :]
            depth_pre = depth_pre_wo_rgb + 4 * (noise_pred[1, 4:, :, :] - depth_pre_wo_rgb)
            depth_latent = self.depth_scheduler.step(depth_pre, t, depth_latent, generator=generator).prev_sample

        return rgb_latent, depth_latent

    def joint_inpaint(
        self,
        rgb_in,
        depth_in,
        image_mask,
        text_embed,
        timesteps,
        generator,
        guidance_scale,
    ):
        # Inpaint the masked region in RGB and depth jointly, conditioned on the known (unmasked) content.
        bs = rgb_in.shape[0]
        h, w = int(rgb_in.shape[-2] / 8), int(rgb_in.shape[-1] / 8)
        rgb_latent = torch.randn(
            [bs, 4, h, w],
            device=self.device,
            dtype=self.unet.dtype,
            generator=generator,
        ) * self.rgb_scheduler.init_noise_sigma
        rgb_mask = image_mask
        rgb_mask_latent = self.encode_rgb(rgb_in * (rgb_mask.squeeze() < 0.5), generator=generator)

        depth_latent = torch.randn(
            [bs, 4, h, w],
            device=self.device,
            dtype=self.unet.dtype,
            generator=generator,
        ) * self.depth_scheduler.init_noise_sigma
        depth_mask = image_mask
        depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))

        rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
        depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])

        for i, t in enumerate(timesteps):
            # concatenate RGB- and depth-branch inputs along the channel dimension
            cat_latent = torch.cat(
                [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
                 depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
                dim=1,
            ).float()
            latent_model_input = torch.cat([cat_latent] * 2)

            # predict the noise residual
            with torch.no_grad():
                partial_noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                    depth2rgb_scale=0,
                    rgb2depth_scale=0.2,
                )[0]

                noise_pred = self.unet(
                    latent_model_input,
                    rgb_timestep=t,
                    depth_timestep=t,
                    encoder_hidden_states=text_embed,
                    return_dict=False,
                )[0]

            # perform guidance
            noise_pred_untext_undual, noise_pred_undual = partial_noise_pred.chunk(2)
            noise_pred_untext, noise_pred_cond = noise_pred.chunk(2)
            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            depth_noise_pred = noise_pred_undual + 3 * (noise_pred_cond - noise_pred_undual)

            rgb_latent = self.rgb_scheduler.step(noise_pred_cond[:, :4, :, :], t, rgb_latent, return_dict=False)[0]
            depth_latent = self.depth_scheduler.step(
                depth_noise_pred[:, 4:, :, :], t, depth_latent, generator=generator, return_dict=False
            )[0]

        return rgb_latent, depth_latent

    @torch.no_grad()
    def _rgbd_inpaint(
        self,
        input_image: Union[torch.Tensor, PIL.Image.Image],
        depth_image: Union[torch.Tensor, PIL.Image.Image],
        mask: Union[torch.Tensor, PIL.Image.Image],
        prompt: str = "",
        guidance_scale: float = 4.5,
        generator: Union[torch.Generator, None] = None,
        num_inference_steps: int = 50,
        resample_method: str = "bilinear",
        processing_res: int = 512,
        mode: str = "full_depth_rgb_inpaint",
    ) -> PIL.Image.Image:
        self._check_inference_step(num_inference_steps)

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Encode prompt -----------------
        if isinstance(prompt, list):
            bs = len(prompt)
            batch_text_embed = []
            for p in prompt:
                batch_text_embed.append(self.encode_text(p))
            batch_text_embed = torch.cat(batch_text_embed, dim=0)
        elif isinstance(prompt, str):
            bs = 1
            batch_text_embed = self.encode_text(prompt).unsqueeze(0)
        else:
            raise NotImplementedError

        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (batch_text_embed.shape[0], 1, 1)
        ).to(self.device)  # [B, 2, 1024]
        text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            rgb_in = (
                self.image_processor.preprocess(input_image, height=processing_res, width=processing_res)
                .to(self.dtype)
                .to(self.device)
            )
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image.unsqueeze(0)
            input_size = rgb.shape
            assert (
                4 == rgb.dim() and 3 == input_size[-3]
            ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
            if processing_res > 0:
                rgb = resize(rgb, [processing_res, processing_res], resample_method, antialias=True)
            rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
            rgb_in = rgb_norm.to(self.dtype).to(self.device)
            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        if isinstance(depth_image, Image.Image):
            depth = pil_to_tensor(depth_image)
            depth = depth.unsqueeze(0)  # [1, 1, H, W]
        elif isinstance(depth_image, torch.Tensor):
            if len(depth_image.shape) == 3:
                depth = depth_image.unsqueeze(0)
            else:
                depth = depth_image
        depth = depth.repeat(1, 3, 1, 1)
        input_size = depth.shape
        assert (
            4 == depth.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, 3, H, W]"
        if processing_res > 0:
            depth = resize(depth, [processing_res, processing_res], resample_method, antialias=True)
        depth_norm: torch.Tensor = (depth - depth.min()) / (depth.max() - depth.min()) * 2.0 - 1.0  # min-max -> [-1, 1]
        depth_in = depth_norm.to(self.dtype).to(self.device)
        assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0

        if (mask.max() - mask.min()) != 0:
            mask = (mask - mask.min()) / (mask.max() - mask.min()) * 255
        image_mask = self.mask_processor.preprocess(mask, height=processing_res, width=processing_res).to(self.device)

        self.rgb_scheduler.set_timesteps(num_inference_steps, device=self.device)
        self.depth_scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.rgb_scheduler.timesteps

        if mode == "full_rgb_depth_inpaint":
            rgb_latent, depth_latent = self.full_rgb_depth_inpaint(
                rgb_in, depth_in, image_mask, text_embed, timesteps, generator, guidance_scale=guidance_scale
            )
        if mode == "partial_depth_rgb_inpaint":
            rgb_latent, depth_latent = self.partial_depth_rgb_inpaint(
                rgb_in, depth_in, image_mask, text_embed, timesteps, generator, guidance_scale=guidance_scale
            )
        if mode == "full_depth_rgb_inpaint":
            rgb_latent, depth_latent = self.full_depth_rgb_inpaint(
                rgb_in, depth_in, image_mask, text_embed, timesteps, generator, guidance_scale=guidance_scale
            )
        if mode == "joint_inpaint":
            rgb_latent, depth_latent = self.joint_inpaint(
                rgb_in, depth_in, image_mask, text_embed, timesteps, generator, guidance_scale=guidance_scale
            )

        image = self.decode_image(rgb_latent)
        image = self.numpy_to_pil(image)[0]
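
        # Decode the depth latent and min-max normalize the result to [0, 1] so it can be rendered
        # as an 8-bit preview alongside the inpainted RGB image.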
        d_image = self.decode_depth(depth_latent)
        d_image = d_image.cpu().permute(0, 2, 3, 1).numpy()
        d_image = (d_image - d_image.min()) / (d_image.max() - d_image.min())
        d_image = self.numpy_to_pil(d_image)[0]

        depth = depth.squeeze().permute(1, 2, 0).cpu().numpy()
        depth = (depth - depth.min()) / (depth.max() - depth.min())
        ori_depth = self.numpy_to_pil(depth)[0]

        ori_image = input_image.squeeze().permute(1, 2, 0).cpu().numpy()
        ori_image = self.numpy_to_pil(ori_image / 255)[0]
        image_mask = self.numpy_to_pil(image_mask.permute(0, 2, 3, 1).cpu().numpy())[0]

        cat_image = make_image_grid([ori_image, ori_depth, image_mask, image, d_image], rows=1, cols=5)
        return cat_image

    def encode_rgb(self, rgb_in: torch.Tensor, generator=None) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        image_latents = self.vae.encode(rgb_in).latent_dist.sample(generator=generator)
        image_latents = self.vae.config.scaling_factor * image_latents
        return image_latents

    def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
        """
        Encode depth map into latent.

        Args:
            depth_in (`torch.Tensor`):
                Input depth map (replicated to 3 channels) to be encoded.

        Returns:
            `torch.Tensor`: Depth latent.
        """
        # encode
        h = self.vae.encoder(depth_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        depth_latent = mean * self.depth_latent_scale_factor
        return depth_latent

    def decode_image(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        z = self.vae.post_quant_conv(latents)
        image = self.vae.decoder(z)
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    def post_process_rgbd(self, prompts, rgb_image, depth_image):
        # Stitch each (RGB, depth) pair side by side and caption the result with its prompt.
        rgbd_images = []
        for idx, p in enumerate(prompts):
            image1, image2 = rgb_image[idx], depth_image[idx]
            width1, height1 = image1.size
            width2, height2 = image2.size

            font = ImageFont.load_default(size=20)
            text = p
            draw = ImageDraw.Draw(image1)
            text_bbox = draw.textbbox((0, 0), text, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            new_image = Image.new(
                "RGB", (width1 + width2, max(height1, height2) + text_height), (255, 255, 255)
            )

            text_x = (new_image.width - text_width) // 2
            text_y = 0
            draw = ImageDraw.Draw(new_image)
            draw.text((text_x, text_y), text, fill="black", font=font)

            new_image.paste(image1, (0, text_height))
            new_image.paste(image2, (width1, text_height))
            rgbd_images.append(pil_to_tensor(new_image))
        return rgbd_images
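

# --------------------------------------------------------------------------
# Minimal self-contained sketch (not part of the original pipeline): illustrates the
# `rescale_noise_cfg` helper above on dummy tensors. With `guidance_rescale=1.0` the
# combined CFG prediction is rescaled so its per-sample std matches the text-conditioned
# prediction, which is the overexposure fix described in arXiv:2305.08891, Sec. 3.4.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    noise_pred_text = torch.randn(2, 4, 64, 64)
    noise_pred_uncond = torch.randn(2, 4, 64, 64)
    guidance_scale = 7.5
    # standard classifier-free guidance combination
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
    print("std of CFG prediction: ", noise_cfg.std(dim=[1, 2, 3]).tolist())
    print("std after rescaling:   ", rescaled.std(dim=[1, 2, 3]).tolist())
    print("std of text prediction:", noise_pred_text.std(dim=[1, 2, 3]).tolist())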