Utilities for Fully Sharded Data Parallelism
accelerate.utils.enable_fsdp_ram_efficient_loading
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
accelerate.utils.disable_fsdp_ram_efficient_loading
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
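A minimal sketch of toggling this behavior, assuming both helpers are importable from accelerate.utils as named above:

from accelerate.utils import (
    enable_fsdp_ram_efficient_loading,
    disable_fsdp_ram_efficient_loading,
)

# Turn on RAM-efficient loading before instantiating the Hugging Face model;
# this sets an environment flag that is read when the model is loaded for FSDP.
enable_fsdp_ram_efficient_loading()

# ... load the model and prepare it with Accelerator ...

# Turn it back off if later loads should use the default behavior.
disable_fsdp_ram_efficient_loading()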
accelerate.utils.merge_fsdp_weights
< source >( checkpoint_dir: str output_path: str safe_serialization: bool = True remove_checkpoint_dir: bool = False )
Parameters
- checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
- output_path (str) — The path to save the merged checkpoint.
- safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization, else to pytorch_model.bin.

Note: this is a CPU-bound process.
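For illustration, a sketch of merging a sharded checkpoint produced with SHARDED_STATE_DICT; the directory and output paths below are hypothetical:

from accelerate.utils import merge_fsdp_weights

# Hypothetical paths: a directory of FSDP shards written with SHARDED_STATE_DICT
# and the location where the single merged file should be written.
merge_fsdp_weights(
    checkpoint_dir="outputs/checkpoint_0/pytorch_model_fsdp_0",
    output_path="outputs/merged",
    safe_serialization=True,      # writes outputs/merged/model.safetensors
    remove_checkpoint_dir=False,  # keep the original shards
)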
class accelerate.FullyShardedDataParallelPlugin
< source >( sharding_strategy: Union = None backward_prefetch: Union = None mixed_precision_policy: Union = None auto_wrap_policy: Union = None cpu_offload: Union = None ignored_modules: Optional = None state_dict_type: Union = None state_dict_config: Union = None optim_state_dict_config: Union = None limit_all_gathers: bool = True use_orig_params: bool = None param_init_fn: Optional = None sync_module_states: bool = None forward_prefetch: bool = None activation_checkpointing: bool = None cpu_ram_efficient_loading: bool = None transformer_cls_names_to_wrap: Optional = None min_num_params: Optional = None )
This plugin is used to enable fully sharded data parallelism.
set_auto_wrap_policy
Given model, creates an auto_wrap_policy based on the passed-in policy and, where applicable, the transformer_cls_to_wrap.
set_mixed_precision
Sets the mixed precision policy for FSDP.
set_state_dict_type
Sets the state dict config based on the StateDictType.
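As an illustrative sketch rather than a canonical recipe, the plugin can be constructed directly and handed to an Accelerator; the specific argument values below are assumptions chosen from the constructor signature above:

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Assumed configuration for illustration; any constructor argument from the
# signature above can be set here instead.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",
    state_dict_type="SHARDED_STATE_DICT",
    use_orig_params=True,
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# The model, optimizer, and dataloaders are then wrapped via accelerator.prepare(...)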