import mamba_ssm
from typing import Optional

from transformers import PretrainedConfig

# Instantiate the library's native config once so this Hugging Face wrapper
# stays in sync with mamba_ssm's own defaults.
mamba_config_defaults = mamba_ssm.models.config_mamba.MambaConfig()


class MambaConfig(PretrainedConfig):
    """Hugging Face `PretrainedConfig` wrapper around `mamba_ssm`'s native config."""

    model_type = "mamba"

    def __init__(
        self,
        d_model: int = mamba_config_defaults.d_model,
        fused_add_norm: bool = mamba_config_defaults.fused_add_norm,
        n_layer: int = mamba_config_defaults.n_layer,
        pad_vocab_size_multiple: int = mamba_config_defaults.pad_vocab_size_multiple,
        residual_in_fp32: bool = mamba_config_defaults.residual_in_fp32,
        rms_norm: bool = mamba_config_defaults.rms_norm,
        ssm_cfg: Optional[dict] = None,
        vocab_size: int = mamba_config_defaults.vocab_size,
        **kwargs,
    ):
        self.d_model = d_model
        self.fused_add_norm = fused_add_norm
        self.n_layer = n_layer
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.residual_in_fp32 = residual_in_fp32
        self.rms_norm = rms_norm
        # Copy the library default rather than using it as a parameter default,
        # so instances never share (and mutate) one dict.
        self.ssm_cfg = dict(mamba_config_defaults.ssm_cfg) if ssm_cfg is None else ssm_cfg
        self.vocab_size = vocab_size
        # PretrainedConfig handles the remaining standard arguments
        # (e.g. bos/eos token ids) after the Mamba-specific fields are set.
        super().__init__(**kwargs)
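
# A minimal usage sketch (an assumption, not taken from the source): because the
# wrapper subclasses PretrainedConfig, it inherits the standard save/load
# helpers. The "mamba-config" directory name below is purely illustrative.
if __name__ == "__main__":
    config = MambaConfig(n_layer=48)            # override one field, keep other defaults
    print(config.d_model, config.n_layer)

    config.save_pretrained("mamba-config")      # writes config.json to the directory
    reloaded = MambaConfig.from_pretrained("mamba-config")
    assert reloaded.n_layer == 48               # round-trips through JSON intact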