# Adapted from CogVideo

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------

from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook

from .modules import CogVideoXDownsample3D, CogVideoXUpsample3D

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.

    Inputs above the size threshold are split along dim 2, convolved chunk by chunk and concatenated again.
    Neighbouring chunks are overlapped so that the concatenated result has the same number of frames as a single
    convolution over the whole input. The chunking assumes a temporal stride of 1 and no temporal padding in the
    convolution itself (each chunk would otherwise be padded/strided independently).
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the convolution, chunking along dim 2 when the input is large.

        Args:
            input (`torch.Tensor`): Input tensor; dim 2 is the axis that gets chunked.

        Returns:
            `torch.Tensor`: The convolution output.
        """
        # Size estimate in GiB, counting 2 bytes per element regardless of the actual dtype.
        memory_count = input.numel() * 2 / 1024**3

        # Set to 2GB, suitable for CuDNN
        if memory_count <= 2:
            return super().forward(input)

        part_num = int(memory_count / 2) + 1
        input_chunks = torch.chunk(input, part_num, dim=2)

        # Every chunk after the first needs the trailing frames of its predecessor, otherwise the outputs that
        # straddle a chunk boundary would be lost. The receptive field along dim 2 spans
        # `dilation * (kernel_size - 1) + 1` frames, hence the overlap below.
        overlap = self.dilation[0] * (self.kernel_size[0] - 1)
        if overlap > 0:
            input_chunks = [input_chunks[0]] + [
                torch.cat((previous[:, :, -overlap:], current), dim=2)
                for previous, current in zip(input_chunks[:-1], input_chunks[1:])
            ]

        # Bind once: zero-argument `super()` is not usable inside a comprehension on older Python versions.
        conv_forward = super().forward
        output_chunks = [conv_forward(chunk) for chunk in input_chunks]
        return torch.cat(output_chunks, dim=2)


class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    The temporal axis (dim 2) is padded only on the "past" side: on the first call the first frame is repeated
    `time_kernel_size - 1` times; on later calls the last `time_kernel_size - 1` frames of the previous input
    (kept in `conv_cache`) are prepended instead. Height and width are zero-padded by `kernel_size // 2`.

    Because the cache survives across calls, `_clear_fake_context_parallel_cache` must be called before feeding an
    unrelated sequence.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of output channels.
        kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
        stride (int, optional): Stride of the convolution. Default is 1.
        dilation (int, optional): Dilation rate of the convolution. Default is 1.
        pad_mode (str, optional): Padding mode. Default is "constant". Stored on the module only; `forward` always
            pads as described above.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        # Kept for reference/compatibility; `forward` pads time and space separately and does not read it.
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        # Stride and dilation only apply along the temporal axis.
        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

        # Trailing frames of the previous input (on CPU), or None when there is no history.
        self.conv_cache = None

    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Prepend the causal temporal context to `inputs`.

        Args:
            inputs (`torch.Tensor`): Input tensor whose dim 2 is the temporal axis.

        Returns:
            `torch.Tensor`: `inputs` unchanged when the temporal kernel size is 1, otherwise `inputs` with
            `time_kernel_size - 1` context frames prepended along the temporal axis.
        """
        dim = self.temporal_dim
        kernel_size = self.time_kernel_size
        if kernel_size == 1:
            return inputs

        # Move the temporal axis to the front so the context can be concatenated along dim 0.
        inputs = inputs.transpose(0, dim)

        if self.conv_cache is not None:
            inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
        else:
            # No history yet: replicate the first frame.
            inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)

        inputs = inputs.transpose(0, dim).contiguous()
        return inputs

    def _clear_fake_context_parallel_cache(self):
        """Drop the cached context frames so the next call starts from a fresh sequence."""
        self.conv_cache = None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply the causal convolution and remember the trailing frames for the next call.

        Args:
            inputs (`torch.Tensor`): Input tensor whose dim 2 is the temporal axis.

        Returns:
            `torch.Tensor`: The convolution output.
        """
        input_parallel = self.fake_context_parallel_forward(inputs)

        self._clear_fake_context_parallel_cache()
        # Only a temporal kernel larger than 1 needs history. For a kernel of 1 the slice below would be
        # `[:, :, 0:]`, i.e. the whole tensor would be cloned and copied to the CPU without ever being read.
        if self.time_kernel_size > 1:
            self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()

        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)

        return self.conv(input_parallel)
def _run_with_checkpoint(module: nn.Module, *inputs):
    """Call `module(*inputs)` through `torch.utils.checkpoint.checkpoint` (activation checkpointing)."""

    def custom_forward(*args):
        return module(*args)

    return torch.utils.checkpoint.checkpoint(custom_forward, *inputs)


class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
    to 3D-video like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`, *optional*, defaults to 32):
            The number of groups used by the group normalization layer.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        """
        Normalize `f` and modulate it with a scale and a bias computed from `zq`.

        Args:
            f (`torch.Tensor`): Feature tensor; its last three dims are the size `zq` is resized to.
            zq (`torch.Tensor`): Conditioning tensor.

        Returns:
            `torch.Tensor`: `norm(f) * conv_y(zq) + conv_b(zq)`.
        """
        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
            # Odd number of frames: resize the first frame and the remaining frames separately.
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
            z_first = F.interpolate(z_first, size=f_first_size)
            z_rest = F.interpolate(z_rest, size=f_rest_size)
            zq = torch.cat([z_first, z_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (Optional[int], optional):
            Number of output channels. If None, defaults to `in_channels`. Default is None.
        dropout (float, optional): Dropout rate. Default is 0.0.
        temb_channels (int, optional):
            Number of time embedding channels. Default is 512. The time embedding projection is only created when
            this is greater than 0.
        groups (int, optional): Number of groups for group normalization. Default is 32.
        eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
        non_linearity (str, optional): Activation function to use. Default is "swish".
        conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
        spatial_norm_dim (Optional[int], optional):
            Dimension of the spatial normalization. If None, plain group normalization is used. Default is None.
        pad_mode (str, optional): Padding mode. Default is "first".
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut

        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # A shortcut projection is only needed when the residual branch changes the channel count.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the residual block.

        Args:
            inputs (`torch.Tensor`): Input features.
            temb (`torch.Tensor`, *optional*):
                Time embedding; when given it is projected and added after the first convolution. Requires the
                block to have been built with `temb_channels > 0`.
            zq (`torch.Tensor`, *optional*):
                Conditioning tensor passed to the normalization layers. Must be given if and only if the block was
                built with `spatial_norm_dim`.

        Returns:
            `torch.Tensor`: `inputs` (projected if the channel count changes) plus the residual branch.
        """
        hidden_states = inputs

        if zq is not None:
            hidden_states = self.norm1(hidden_states, zq)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states = self.norm2(hidden_states, zq)
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states


class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional):
            Number of groups for group normalization in the ResNet layers. Default is 32.
        add_downsample (bool, optional):
            If True, add a downsampling layer at the end of the block. Default is True.
        downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
        compress_time (bool, optional): If True, apply temporal compression. Default is False.
        pad_mode (str, optional): Padding mode. Default is "first".
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        # Only the first ResNet layer changes the channel count.
        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
                for i in range(num_layers)
            ]
        )

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the ResNet layers followed by the optional downsamplers and return the result."""
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet in self.resnets:
            if use_checkpoint:
                hidden_states = _run_with_checkpoint(resnet, hidden_states, temb, zq)
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


class CogVideoXMidBlock3D(nn.Module):
    r"""
    A middle block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional):
            Number of groups for group normalization in the ResNet layers. Default is 32.
        spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
        pad_mode (str, optional): Padding mode. Default is "first".
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the ResNet layers in sequence and return the result."""
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet in self.resnets:
            if use_checkpoint:
                hidden_states = _run_with_checkpoint(resnet, hidden_states, temb, zq)
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        return hidden_states


class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional):
            Number of groups for group normalization in the ResNet layers. Default is 32.
        spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
        add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
        upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
        compress_time (bool, optional): If True, apply temporal compression. Default is False.
        pad_mode (str, optional): Padding mode. Default is "first".
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        # Only the first ResNet layer changes the channel count.
        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
                for i in range(num_layers)
            ]
        )

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [CogVideoXUpsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `CogVideoXUpBlock3D` class."""
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet in self.resnets:
            if use_checkpoint:
                hidden_states = _run_with_checkpoint(resnet, hidden_states, temb, zq)
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels. The encoder emits `2 * out_channels` channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"`):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon of the normalization layers inside the ResNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the ResNet blocks.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the causal convolutions.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Temporal compression factor; the first `int(log2(ratio))` down blocks compress time.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            if down_block_type != "CogVideoXDownBlock3D":
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            down_block = CogVideoXDownBlock3D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=0,
                dropout=dropout,
                num_layers=layers_per_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_downsample=not is_final_block,
                compress_time=compress_time,
                # Forwarded like for every other sub-module so a non-default pad_mode is applied consistently.
                pad_mode=pad_mode,
            )
            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # Twice the latent channels: the output is consumed as the parameters of a diagonal Gaussian.
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""The forward method of the `CogVideoXEncoder3D` class."""
        hidden_states = self.conv_in(sample)
        use_checkpoint = self.training and self.gradient_checkpointing

        # 1. Down
        for down_block in self.down_blocks:
            if use_checkpoint:
                hidden_states = _run_with_checkpoint(down_block, hidden_states, temb, None)
            else:
                hidden_states = down_block(hidden_states, temb, None)

        # 2. Mid
        if use_checkpoint:
            hidden_states = _run_with_checkpoint(self.mid_block, hidden_states, temb, None)
        else:
            hidden_states = self.mid_block(hidden_states, temb, None)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input (latent) channels. Also used as the spatial-norm conditioning dimension.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"`):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block (given in encoder order; reversed internally).
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block. Each up block uses `layers_per_block + 1` layers.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon of the normalization layers inside the ResNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the ResNet blocks.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the causal convolutions.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Temporal compression factor; the first `int(log2(ratio))` up blocks expand time.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            # Previously omitted, which made the mid block ignore the decoder's `dropout` argument.
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            if up_block_type != "CogVideoXUpBlock3D":
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            up_block = CogVideoXUpBlock3D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=0,
                dropout=dropout,
                num_layers=layers_per_block + 1,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                spatial_norm_dim=in_channels,
                add_upsample=not is_final_block,
                compress_time=compress_time,
                pad_mode=pad_mode,
            )
            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""The forward method of the `CogVideoXDecoder3D` class."""
        hidden_states = self.conv_in(sample)
        use_checkpoint = self.training and self.gradient_checkpointing

        # The latent `sample` itself is the conditioning input (`zq`) of every spatial norm.
        # 1. Mid
        if use_checkpoint:
            hidden_states = _run_with_checkpoint(self.mid_block, hidden_states, temb, sample)
        else:
            hidden_states = self.mid_block(hidden_states, temb, sample)

        # 2. Up
        for up_block in self.up_blocks:
            if use_checkpoint:
                hidden_states = _run_with_checkpoint(up_block, hidden_states, temb, sample)
            else:
                hidden_states = up_block(hidden_states, temb, sample)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
    [CogVideoX](https://github.com/THUDM/CogVideo).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Note:
        The causal convolutions keep the trailing frames of their previous input between calls. Call
        [`~AutoencoderKLCogVideoX.clear_fake_context_parallel_cache`] before encoding or decoding an unrelated
        video; `encode` and `decode` do not clear the cache themselves.

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXDownBlock3D"`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXUpBlock3D"`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
            Tuple of block output channels.
        latent_channels (`int`, *optional*, defaults to 16): Number of channels in the latent space.
        layers_per_block (`int`, *optional*, defaults to 3): Number of layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_eps (`float`, *optional*, defaults to 1e-6): Epsilon of the normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups for normalization.
        temporal_compression_ratio (`float`, *optional*, defaults to 4): Temporal compression factor.
        sample_size (`int`, *optional*, defaults to `256`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Only stored in the
            config; this class does not apply it.
        shift_factor (`float`, *optional*): Only stored in the config.
        latents_mean (`Tuple[float]`, *optional*): Only stored in the config.
        latents_std (`Tuple[float]`, *optional*): Only stored in the config.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without loosing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
        use_quant_conv (`bool`, *optional*, defaults to `False`):
            Whether to apply a 1x1x1 convolution to the encoder output before building the posterior.
        use_post_quant_conv (`bool`, *optional*, defaults to `False`):
            Whether to apply a 1x1x1 convolution to the latents before decoding.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["CogVideoXResnetBlock3D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        up_block_types: Tuple[str] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int] = (128, 256, 256, 512),
        latent_channels: int = 16,
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_size: int = 256,
        scaling_factor: float = 1.15258426,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
    ):
        super().__init__()

        self.encoder = CogVideoXEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CogVideoXDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )

        # Both convolutions act on latent tensors: the encoder emits `2 * latent_channels` channels and the
        # decoder consumes `latent_channels`. Sizing them with `out_channels` (the image channels) only worked
        # when the two happened to be equal.
        self.quant_conv = CogVideoXSafeConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = (
            CogVideoXSafeConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
        )

        self.use_slicing = False
        self.use_tiling = False

        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on the encoder and decoder modules."""
        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
            module.gradient_checkpointing = value

    def clear_fake_context_parallel_cache(self):
        """Reset the cached temporal context of every `CogVideoXCausalConv3d` in the model."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
        """
        Encode `sample`, pick a latent and decode it again.

        Args:
            sample (`torch.Tensor`): Input batch.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior instead of taking its mode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the [`~models.vae.DecoderOutput`] directly or wrapped in a 1-tuple.
            generator (`torch.Generator`, *optional*): Generator used when sampling from the posterior.

        Returns:
            The [`~models.vae.DecoderOutput`] produced by `decode`, or a 1-tuple containing it when `return_dict`
            is False (the tuple element is the output object, not the raw tensor).
        """
        x = sample
        posterior = self.encode(x).latent_dist
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        dec = self.decode(z)

        if not return_dict:
            return (dec,)
        return dec