import torch.nn as nn

from ...util import append_dims, instantiate_from_config


class Denoiser(nn.Module):
    """Combine a network's output with its input using sigma-dependent scalings.

    With ``(c_skip, c_out, c_in, c_noise) = self.scaling(sigma)`` the result is::

        network(input * c_in, c_noise, cond) * c_out + input * c_skip

    Subclasses may override ``possibly_quantize_sigma`` and
    ``possibly_quantize_c_noise`` to snap values onto a discrete set; this base
    class passes both through unchanged.

    NOTE(review): this class overrides ``__call__`` instead of ``forward``, so
    hooks registered through ``nn.Module`` (forward / forward-pre hooks) are not
    run when it is called.
    """

    def __init__(self, weighting_config, scaling_config):
        """Build the weighting and scaling callables from their configs.

        Args:
            weighting_config: config for ``instantiate_from_config``; the
                resulting object is what ``w`` calls.
            scaling_config: config for ``instantiate_from_config``; the
                resulting object is called with sigma and must return the
                4-tuple ``(c_skip, c_out, c_in, c_noise)``.
        """
        super().__init__()
        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
        """Hook for snapping ``sigma`` to an allowed value; identity here."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
        """Hook for snapping the network's noise input; identity here."""
        return c_noise

    def w(self, sigma):
        """Return ``self.weighting(sigma)`` (presumably a per-sigma loss weight)."""
        return self.weighting(sigma)

    def _denoise(self, network, input, sigma, cond, *network_args):
        """Shared body of ``__call__`` for this class and its subclasses.

        Any ``network_args`` are passed to ``network`` after ``cond``.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        # Keep sigma's own shape: c_noise is handed to the network in this
        # shape, not in the expanded shape used for the other scalings.
        sigma_shape = sigma.shape
        # Presumably appends trailing singleton dims up to input.ndim so the
        # c_* scalings broadcast against `input` — TODO confirm in util.
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return (
            network(input * c_in, c_noise, cond, *network_args) * c_out
            + input * c_skip
        )

    def __call__(self, network, input, sigma, cond):
        """Denoise ``input`` at noise level ``sigma`` with ``network``.

        Args:
            network: callable invoked as ``network(scaled_input, c_noise, cond)``.
            input: noisy input tensor.
            sigma: noise level(s), first passed through
                ``possibly_quantize_sigma``.
            cond: conditioning forwarded to ``network`` unchanged.

        Returns:
            ``network(input * c_in, c_noise, cond) * c_out + input * c_skip``.
        """
        return self._denoise(network, input, sigma, cond)


class DiscreteDenoiser(Denoiser):
    """Denoiser whose sigma values are snapped to a fixed table ``self.sigmas``.

    The table is produced by calling the object built from
    ``discretization_config`` with ``num_idx``. When ``quantize_c_noise`` is
    true, the network's noise input is replaced by the index of the nearest
    table entry.
    """

    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
    ):
        """Build the scaling/weighting objects and the sigma table.

        Args:
            weighting_config: see ``Denoiser.__init__``.
            scaling_config: see ``Denoiser.__init__``.
            num_idx: passed as the first argument to the discretization object.
            discretization_config: config for ``instantiate_from_config``; the
                resulting object is called to produce the sigma table.
            do_append_zero: forwarded to the discretization call.
            quantize_c_noise: if true, ``possibly_quantize_c_noise`` maps
                c_noise to a table index; otherwise it is passed through.
            flip: forwarded to the discretization call.
        """
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        # A buffer (not a Parameter): it moves with .to()/.cuda() and is saved
        # in state_dict, but is not returned by parameters().
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

    def sigma_to_idx(self, sigma):
        """Return, per element of ``sigma``, the index of the nearest table entry.

        The result has the same shape as ``sigma``. Assumes ``self.sigmas`` is
        1-D. On exact ties, ``argmin`` returns the first minimal index.
        """
        # Flatten so every element is compared with every table entry:
        # (N, 1) - (numel,) -> (N, numel). Without this, a sigma with
        # ndim >= 2 broadcasts against the table's trailing singleton dim and
        # argmin reduces over the wrong axis.
        dists = sigma.reshape(-1) - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
        """Look up the table value(s) at ``idx``."""
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma):
        """Replace each sigma with its nearest entry in ``self.sigmas``."""
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
        """Map c_noise to a table index if ``quantize_c_noise``, else pass through.

        NOTE(review): this compares c_noise directly against the sigma table, so
        it is only meaningful when the scaling's c_noise is on the same scale as
        sigma — confirm for the configured scaling.
        """
        return self.sigma_to_idx(c_noise) if self.quantize_c_noise else c_noise


class DiscreteDenoiserWithControl(DiscreteDenoiser):
    """DiscreteDenoiser that also forwards ``control_scale`` to the network."""

    def __call__(self, network, input, sigma, cond, control_scale):
        """Like ``DiscreteDenoiser.__call__``, but calls the network as
        ``network(input * c_in, c_noise, cond, control_scale)``.
        """
        return self._denoise(network, input, sigma, cond, control_scale)