next / trl /trainer /ddpo_config.py
BiXie's picture
Upload 204 files
252711e verified
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available
@dataclass
class DDPOConfig:
    """
    Configuration class for DDPOTrainer

    Every option is a dataclass field with a default value, so ``DDPOConfig()`` is a valid
    configuration on its own. After the generated ``__init__`` runs, ``__post_init__`` checks the
    logging and optimizer settings (see its docstring). ``to_dict`` exposes the current values as a
    dictionary passed through ``flatten_dict``.
    """

    # common parameters
    # `splitext` only removes a real extension; slicing off a fixed three characters would mangle
    # launcher names that do not end in ".py" (e.g. "pytest" -> "pyt").
    exp_name: str = os.path.splitext(os.path.basename(sys.argv[0]))[0]
    """the name of this experiment (by default is the file name without the extension name)"""
    run_name: Optional[str] = ""
    """Run name for wandb logging and checkpoint saving."""
    seed: int = 0
    """Seed value for random generations"""
    log_with: Optional[Literal["wandb", "tensorboard"]] = None
    """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
    tracker_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the tracker (e.g. wandb_project)"""
    accelerator_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator"""
    project_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
    tracker_project_name: str = "trl"
    """Name of project to use for tracking"""
    logdir: str = "logs"
    """Top-level logging directory for checkpoint saving."""

    # hyperparameters
    num_epochs: int = 100
    """Number of epochs to train."""
    save_freq: int = 1
    """Number of epochs between saving model checkpoints."""
    num_checkpoint_limit: int = 5
    """Number of checkpoints to keep before overwriting old ones."""
    mixed_precision: str = "fp16"
    """Mixed precision training."""
    allow_tf32: bool = True
    """Allow tf32 on Ampere GPUs."""
    resume_from: Optional[str] = ""
    """Resume training from a checkpoint."""
    sample_num_steps: int = 50
    """Number of sampler inference steps."""
    sample_eta: float = 1.0
    """Eta parameter for the DDIM sampler."""
    sample_guidance_scale: float = 5.0
    """Classifier-free guidance weight."""
    sample_batch_size: int = 1
    """Batch size (per GPU!) to use for sampling."""
    sample_num_batches_per_epoch: int = 2
    """Number of batches to sample per epoch."""
    train_batch_size: int = 1
    """Batch size (per GPU!) to use for training."""
    train_use_8bit_adam: bool = False
    """Whether to use the 8bit Adam optimizer from bitsandbytes."""
    train_learning_rate: float = 3e-4
    """Learning rate."""
    train_adam_beta1: float = 0.9
    """Adam beta1."""
    train_adam_beta2: float = 0.999
    """Adam beta2."""
    train_adam_weight_decay: float = 1e-4
    """Adam weight decay."""
    train_adam_epsilon: float = 1e-8
    """Adam epsilon."""
    train_gradient_accumulation_steps: int = 1
    """Number of gradient accumulation steps."""
    train_max_grad_norm: float = 1.0
    """Maximum gradient norm for gradient clipping."""
    train_num_inner_epochs: int = 1
    """Number of inner epochs per outer epoch."""
    train_cfg: bool = True
    """Whether or not to use classifier-free guidance during training."""
    train_adv_clip_max: float = 5
    """Clip advantages to the range."""
    train_clip_range: float = 1e-4
    """The PPO clip range."""
    train_timestep_fraction: float = 1.0
    """The fraction of timesteps to train on."""
    per_prompt_stat_tracking: bool = False
    """Whether to track statistics for each prompt separately."""
    per_prompt_stat_tracking_buffer_size: int = 16
    """Number of reward values to store in the buffer for each prompt."""
    per_prompt_stat_tracking_min_count: int = 16
    """The minimum number of reward values to store in the buffer."""
    async_reward_computation: bool = False
    """Whether to compute rewards asynchronously."""
    max_workers: int = 2
    """The maximum number of workers to use for async reward computation."""
    negative_prompts: Optional[str] = ""
    """Comma-separated list of prompts to use as negative examples."""

    def to_dict(self):
        """
        Return the configuration values as a dictionary.

        A shallow copy of the instance attributes is handed to `flatten_dict` and its result is
        returned, so the instance's own `__dict__` is never passed out directly.
        """
        return flatten_dict(dict(self.__dict__))

    def __post_init__(self):
        """
        Check logging and optimizer options once the dataclass fields have been set.

        Emits a warning when `log_with` is neither 'wandb' nor 'tensorboard' (this includes the
        default `None`), and another when 'wandb' is selected but torchvision is not available.

        Raises:
            ImportError: if `train_use_8bit_adam` is enabled but bitsandbytes is not available.
        """
        if self.log_with not in ("wandb", "tensorboard"):
            warnings.warn(
                "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
            )

        if self.log_with == "wandb" and not is_torchvision_available():
            warnings.warn("Wandb image logging requires torchvision to be installed")

        if self.train_use_8bit_adam and not is_bitsandbytes_available():
            raise ImportError(
                "You need to install bitsandbytes to use 8bit Adam. You can install it with `pip install bitsandbytes`."
            )