"""Dataclass-based configuration containers for the model components.

Field meanings noted below are partly inferred from field names; where a
comment says "presumably", confirm against the code that consumes the config.
"""
from dataclasses import dataclass
from typing import List

try:
    from typing import Literal
except ImportError:  # Python < 3.8 has no typing.Literal
    from typing_extensions import Literal


@dataclass
class GPTConfig:
    """Hyperparameters for a GPT-style transformer.

    Note: this config is not referenced by ``ModelConfig`` in this module.
    """

    n_layer: int  # presumably the number of transformer blocks
    n_head: int  # presumably the number of attention heads per block
    n_embedding: int  # presumably the embedding / hidden width
    vocab_size: int  # presumably the size of the token vocabulary
    block_size: int  # presumably the maximum context length
    # Presumably dropout probabilities for the embedding and attention layers
    # (despite "percentage" in the name, confirm whether 0-1 or 0-100 is expected).
    embedding_percentage_drop: float
    attention_percentage_drop: float


@dataclass
class VQVAEConfig:
    """Vector-quantizer (codebook) settings."""

    beta: float  # presumably the commitment-loss weight; confirm in the quantizer
    num_embeddings: int  # presumably the number of codebook entries
    embedding_dim: int  # presumably the dimensionality of each codebook vector


@dataclass
class AutoencoderConfig:
    """Encoder/decoder architecture settings."""

    z_channels: int  # presumably the channel count of the latent tensor
    channels: int  # presumably the base channel width
    channels_multiplier: List[int]  # presumably per-resolution-level width multipliers
    num_res_blocks: int  # presumably residual blocks per resolution level
    attention_resolution: List[int]  # presumably resolutions at which attention is applied
    resolution: int  # presumably the input image resolution
    dropout: float


@dataclass
class DiscriminatorConfig:
    """Discriminator architecture settings."""

    num_layers: int
    filters: int  # presumably the base number of convolution filters


@dataclass
class DiscriminatorLossConfig:
    """Settings for the adversarial (discriminator) loss term."""

    # Two distinct allowed values. The original annotation
    # Literal["hinge, vanilla"] was a single literal string.
    loss: Literal["hinge", "vanilla"]
    factor: float
    iter_start: int  # presumably the training step at which the discriminator loss starts
    weight: float


@dataclass
class VQVAELossConfig:
    """Weights for the VQ-VAE loss terms."""

    codebook_weight: float
    perceptual_weight: float


@dataclass
class LossConfig:
    """Aggregates all loss-related settings."""

    discriminator: DiscriminatorLossConfig
    vqvae: VQVAELossConfig
    perceptual_loss: str  # presumably an identifier of the perceptual-loss network; confirm


@dataclass
class ModelConfig:
    """Top-level config grouping the VQ-VAE, autoencoder, discriminator and loss settings."""

    vqvae_config: VQVAEConfig
    autoencoder_config: AutoencoderConfig
    discriminator_config: DiscriminatorConfig
    loss_config: LossConfig