from functools import partial

import torch

from opensora.registry import SCHEDULERS

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


@SCHEDULERS.register_module("iddpm")
class IDDPM(SpacedDiffusion):
    def __init__(
        self,
        num_sampling_steps=None,
        timestep_respacing=None,
        noise_schedule="linear",
        use_kl=False,
        sigma_small=False,
        predict_xstart=False,
        learn_sigma=True,
        rescale_learned_sigmas=False,
        diffusion_steps=1000,
        cfg_scale=4.0,
    ):
        betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
        if use_kl:
            loss_type = gd.LossType.RESCALED_KL
        elif rescale_learned_sigmas:
            loss_type = gd.LossType.RESCALED_MSE
        else:
            loss_type = gd.LossType.MSE
        # `num_sampling_steps` is a convenience alias for `timestep_respacing`;
        # the two are mutually exclusive.
        if num_sampling_steps is not None:
            assert timestep_respacing is None
            timestep_respacing = str(num_sampling_steps)
        if timestep_respacing is None or timestep_respacing == "":
            timestep_respacing = [diffusion_steps]
        super().__init__(
            use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
            betas=betas,
            model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
            model_var_type=(
                (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
                if not learn_sigma
                else gd.ModelVarType.LEARNED_RANGE
            ),
            loss_type=loss_type,
            # rescale_timesteps=rescale_timesteps,
        )

        self.cfg_scale = cfg_scale

    def sample(
        self,
        model,
        text_encoder,
        z_size,
        prompts,
        device,
        additional_args=None,
    ):
        n = len(prompts)
        # Duplicate the noise along the batch dimension: the first half is the
        # conditional branch, the second half the unconditional (null-text) branch.
        z = torch.randn(n, *z_size, device=device)
        z = torch.cat([z, z], 0)
        model_args = text_encoder.encode(prompts)
        y_null = text_encoder.null(n)
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)
        if additional_args is not None:
            model_args.update(additional_args)

        forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
        samples = self.p_sample_loop(
            forward,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs=model_args,
            progress=True,
            device=device,
        )
        # Keep only the conditional half of the batch.
        samples, _ = samples.chunk(2, dim=0)
        return samples


def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs):
    # Classifier-free guidance, following
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = model.forward(combined, timestep, y, **kwargs)
    model_out = model_out["x"] if isinstance(model_out, dict) else model_out
    # Guidance is applied to the noise-prediction channels only; the remaining
    # channels carry the learned variance and are passed through unchanged.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)
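
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a denoising network and a matching text encoder have already been
# built; `model`, `text_encoder`, and the latent size are hypothetical
# placeholders, not values taken from this file.
#
#   scheduler = IDDPM(num_sampling_steps=100, cfg_scale=7.0)
#   samples = scheduler.sample(
#       model,                    # denoising network, e.g. a registered DiT-style model
#       text_encoder,             # must expose .encode(prompts) and .null(n)
#       z_size=(4, 16, 32, 32),   # assumed (C, T, H, W) of the VAE latent
#       prompts=["a timelapse of clouds over a city"],
#       device="cuda",
#   )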