# trl/trainer/model_config.py
from dataclasses import dataclass, field
from typing import List, Optional

from ..core import flatten_dict


@dataclass
class ModelConfig:
    """
    Arguments which define the model and tokenizer to load.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use. You can pass `--attn_implementation=flash_attention_2`, in "
                "which case you must install flash-attn manually with `pip install flash-attn --no-build-isolation`."
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA R value."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "LoRA dropout."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "LoRA target modules."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze and train."},
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Use 8-bit precision for the base model (works only with LoRA)."},
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Use 4-bit precision for the base model (works only with LoRA)."},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type for 4-bit precision (fp4 or nf4)."},
    )
    use_bnb_nested_quant: bool = field(
        default=False,
        metadata={"help": "Use nested (double) quantization for 4-bit precision."},
    )

    def to_dict(self):
        # Copy the dataclass fields into a plain dict, then flatten any nested values.
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
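

# Illustrative usage sketch (not part of the original module, and `_example_usage`
# is a hypothetical name): a minimal example of how a `ModelConfig` is typically
# consumed. It is parsed from the command line with `transformers.HfArgumentParser`,
# and its LoRA and quantization fields are mapped onto `peft.LoraConfig` and
# `transformers.BitsAndBytesConfig`. The mapping below is an assumption about
# typical usage, not TRL's own helper code.
def _example_usage():
    import torch
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig, HfArgumentParser

    # e.g. python train.py --model_name_or_path gpt2 --use_peft --load_in_4bit
    parser = HfArgumentParser(ModelConfig)
    (model_config,) = parser.parse_args_into_dataclasses()

    # Build a LoRA adapter config from the LoRA fields above when PEFT is requested.
    peft_config = None
    if model_config.use_peft:
        peft_config = LoraConfig(
            r=model_config.lora_r,
            lora_alpha=model_config.lora_alpha,
            lora_dropout=model_config.lora_dropout,
            target_modules=model_config.lora_target_modules,
            modules_to_save=model_config.lora_modules_to_save,
            task_type="CAUSAL_LM",
        )

    # Build a bitsandbytes quantization config; __post_init__ has already ruled out
    # requesting 8-bit and 4-bit at the same time.
    quantization_config = None
    if model_config.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
            bnb_4bit_compute_dtype=torch.bfloat16,  # assumed compute dtype for this sketch
        )
    elif model_config.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    return model_config, peft_config, quantization_config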