|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from transformers import TrainingArguments
|
|
|
|
|
|
@dataclass
class RewardConfig(TrainingArguments):
    """
    RewardConfig collects all training arguments related to the [`RewardTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int`, *optional*, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        gradient_checkpointing (`bool`, *optional*, defaults to `True`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    """

    # NOTE(review): `gradient_checkpointing` is documented in the docstring above, but
    # no override is defined in this block; `transformers.TrainingArguments` declares
    # it with a default of `False`. Confirm whether an explicit
    # `gradient_checkpointing: bool = True` field was intended (it may be defined
    # further down the file) — otherwise the documented default is misleading.

    max_length: Optional[int] = None
    """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""