from dataclasses import dataclass, field
from typing import List, Optional

from ..core import flatten_dict


@dataclass
class ModelConfig:
    """
    Arguments which define the model and tokenizer to load.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use; you can pass --attn_implementation=flash_attention_2, in "
                "which case you must install flash-attn manually with `pip install flash-attn --no-build-isolation`."
            )
        },
    )
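
    # PEFT / LoRA options. These are typically only consumed by the training
    # script when `use_peft` is True; the `lora_*` names mirror the matching
    # arguments of `peft.LoraConfig`.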
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA R value."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "LoRA dropout."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "LoRA target modules."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze and train."},
    )
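
    # bitsandbytes quantization options; `load_in_8bit` and `load_in_4bit` are
    # mutually exclusive (enforced in `__post_init__` below).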
    load_in_8bit: bool = field(
        default=False, metadata={"help": "Use 8-bit precision for the base model (works only with LoRA)."}
    )
    load_in_4bit: bool = field(
        default=False, metadata={"help": "Use 4-bit precision for the base model (works only with LoRA)."}
    )

    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "The quantization type to use (fp4 or nf4)."}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})

    def to_dict(self):
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
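

# --- Usage sketch -----------------------------------------------------------
# Illustrative only: `_example_build_configs` is a hypothetical helper, not
# part of the original module. It shows one plausible way a training script
# could translate a parsed `ModelConfig` into the `torch_dtype`,
# `BitsAndBytesConfig`, and `LoraConfig` objects that `from_pretrained` and
# PEFT expect. A script would typically obtain the config via
#     model_config = HfArgumentParser(ModelConfig).parse_args_into_dataclasses()[0]
# (transformers' `HfArgumentParser` reads the `help`/`choices` metadata above).
def _example_build_configs(model_config: ModelConfig):
    import torch
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    # "auto" and None pass through unchanged; a string such as "bfloat16" is
    # resolved to the matching torch dtype (e.g. torch.bfloat16).
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ("auto", None)
        else getattr(torch, model_config.torch_dtype)
    )

    # `__post_init__` guarantees at most one of the two flags is set.
    quantization_config = None
    if model_config.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
        )
    elif model_config.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    peft_config = None
    if model_config.use_peft:
        peft_config = LoraConfig(
            r=model_config.lora_r,
            lora_alpha=model_config.lora_alpha,
            lora_dropout=model_config.lora_dropout,
            target_modules=model_config.lora_target_modules,
            modules_to_save=model_config.lora_modules_to_save,
        )

    return torch_dtype, quantization_config, peft_config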