import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional

from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available


@dataclass
class DDPOConfig:
    """

    Configuration class for DDPOTrainer

    """

    # common parameters
    exp_name: str = os.path.splitext(os.path.basename(sys.argv[0]))[0]
    """Name of this experiment (defaults to the script's file name without the extension)."""
    run_name: str = ""
    """Run name for wandb logging and checkpoint saving."""
    seed: int = 0
    """Seed value for random generations"""
    log_with: Optional[Literal["wandb", "tensorboard"]] = None
    """Log with either 'wandb' or 'tensorboard', check  https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
    tracker_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the tracker (e.g. wandb_project)"""
    accelerator_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator"""
    project_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
    tracker_project_name: str = "trl"
    """Name of project to use for tracking"""
    logdir: str = "logs"
    """Top-level logging directory for checkpoint saving."""

    # hyperparameters
    num_epochs: int = 100
    """Number of epochs to train."""
    save_freq: int = 1
    """Number of epochs between saving model checkpoints."""
    num_checkpoint_limit: int = 5
    """Number of checkpoints to keep before overwriting old ones."""
    mixed_precision: str = "fp16"
    """Mixed precision mode passed to Accelerate; one of 'no', 'fp16', or 'bf16'."""
    allow_tf32: bool = True
    """Allow tf32 on Ampere GPUs."""
    resume_from: str = ""
    """Path to a checkpoint directory to resume training from (leave empty to start from scratch)."""
    sample_num_steps: int = 50
    """Number of sampler inference steps."""
    sample_eta: float = 1.0
    """Eta parameter for the DDIM sampler."""
    sample_guidance_scale: float = 5.0
    """Classifier-free guidance weight."""
    sample_batch_size: int = 1
    """Batch size (per GPU!) to use for sampling."""
    sample_num_batches_per_epoch: int = 2
    """Number of batches to sample per epoch."""
    train_batch_size: int = 1
    """Batch size (per GPU!) to use for training."""
    train_use_8bit_adam: bool = False
    """Whether to use the 8bit Adam optimizer from bitsandbytes."""
    train_learning_rate: float = 3e-4
    """Learning rate."""
    train_adam_beta1: float = 0.9
    """Adam beta1."""
    train_adam_beta2: float = 0.999
    """Adam beta2."""
    train_adam_weight_decay: float = 1e-4
    """Adam weight decay."""
    train_adam_epsilon: float = 1e-8
    """Adam epsilon."""
    train_gradient_accumulation_steps: int = 1
    """Number of gradient accumulation steps."""
    train_max_grad_norm: float = 1.0
    """Maximum gradient norm for gradient clipping."""
    train_num_inner_epochs: int = 1
    """Number of inner epochs per outer epoch."""
    train_cfg: bool = True
    """Whether or not to use classifier-free guidance during training."""
    train_adv_clip_max: float = 5.0
    """Clip advantages to the range [-train_adv_clip_max, train_adv_clip_max]."""
    train_clip_range: float = 1e-4
    """The PPO clip range."""
    train_timestep_fraction: float = 1.0
    """The fraction of timesteps to train on."""
    per_prompt_stat_tracking: bool = False
    """Whether to track statistics for each prompt separately."""
    per_prompt_stat_tracking_buffer_size: int = 16
    """Number of reward values to store in the buffer for each prompt."""
    per_prompt_stat_tracking_min_count: int = 16
    """Minimum number of reward values required in a prompt's buffer before per-prompt statistics are used (global statistics are used below this count)."""
    async_reward_computation: bool = False
    """Whether to compute rewards asynchronously."""
    max_workers: int = 2
    """The maximum number of workers to use for async reward computation."""
    negative_prompts: str = ""
    """Comma-separated list of prompts to use as negative examples."""

    def to_dict(self):
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)
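
    # Note: `flatten_dict` (from `..core`) collapses any nested dict fields
    # (e.g. `tracker_kwargs`) into single-level keys, which makes the result
    # straightforward to log with experiment trackers.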

    def __post_init__(self):
        if self.log_with not in ["wandb", "tensorboard"]:
            warnings.warn(
                "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
            )

        if self.log_with == "wandb" and not is_torchvision_available():
            warnings.warn("Wandb image logging requires torchvision to be installed")

        if self.train_use_8bit_adam and not is_bitsandbytes_available():
            raise ImportError(
                "You need to install bitsandbytes to use 8bit Adam. You can install it with `pip install bitsandbytes`."
            )
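

# Minimal usage sketch (illustrative): `DDPOConfig` is a plain dataclass, so any
# field can be overridden at construction time and unspecified fields keep the
# defaults above. The `trl` top-level import path is assumed here.
#
#     from trl import DDPOConfig
#
#     config = DDPOConfig(
#         num_epochs=10,
#         sample_batch_size=2,
#         train_learning_rate=1e-4,
#         log_with="tensorboard",
#     )
#     print(config.to_dict())  # flattened config, ready to pass to a tracker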