from dataclasses import dataclass, field
from typing import List, Optional

from ..core import flatten_dict


@dataclass
class ModelConfig:
    """
    Arguments which define the model and tokenizer to load.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
                "which case you must install it manually by running `pip install flash-attn --no-build-isolation`."
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA R value."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "LoRA dropout."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "LoRA target modules."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze & train."},
    )
    load_in_8bit: bool = field(
        default=False, metadata={"help": "Use 8 bit precision for the base model - works only with LoRA."}
    )
    load_in_4bit: bool = field(
        default=False, metadata={"help": "Use 4 bit precision for the base model - works only with LoRA."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "Specify the quantization type (fp4 or nf4)."}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Use nested quantization."})

    def to_dict(self):
        # Shallow-copy the dataclass fields before flattening.
        return flatten_dict(dict(self.__dict__))

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
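
# ---------------------------------------------------------------------------
# Usage sketches (illustrative only; not exercised by this module).
#
# 1) `ModelConfig` is a plain dataclass, so it can be parsed from the command
#    line with `transformers.HfArgumentParser`. The script invocation shown in
#    the comment is hypothetical.
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser(ModelConfig)
#     (model_config,) = parser.parse_args_into_dataclasses()
#     # e.g. python train.py --model_name_or_path gpt2 --use_peft --lora_r 8
#
# 2) A minimal sketch, assuming `transformers.BitsAndBytesConfig`, of how the
#    quantization fields might be mapped to a loading config; `__post_init__`
#    guarantees at most one of the two load_in_* flags is set:
#
#     from transformers import BitsAndBytesConfig
#
#     if model_config.load_in_4bit:
#         quant_config = BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
#             bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
#         )
#     elif model_config.load_in_8bit:
#         quant_config = BitsAndBytesConfig(load_in_8bit=True)
#     else:
#         quant_config = None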