Accelerate documentation

Utilities for Fully Sharded Data Parallelism

You are viewing version v0.34.2. A newer version, v1.0.0rc1, is available.


( )

Enables RAM efficient loading of Hugging Face models for FSDP in the environment.


( )

Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
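The two helpers above are natural to use as a pair around model loading. The following is a minimal sketch, not the library's own pattern: the context manager `ram_efficient_loading` is a hypothetical name, and the import path and function names `accelerate.utils.enable_fsdp_ram_efficient_loading` / `disable_fsdp_ram_efficient_loading` are assumptions, since the headings naming these functions are missing from this page.

```python
from contextlib import contextmanager


@contextmanager
def ram_efficient_loading(enable=None, disable=None):
    """Enable RAM-efficient loading for the duration of a `with` block.

    `enable` and `disable` are zero-argument callables. When omitted, they
    fall back to the Accelerate utilities (assumed import path, see below).
    """
    if enable is None or disable is None:
        # Assumption: both helpers are importable from accelerate.utils.
        from accelerate.utils import (
            disable_fsdp_ram_efficient_loading,
            enable_fsdp_ram_efficient_loading,
        )

        enable = enable or enable_fsdp_ram_efficient_loading
        disable = disable or disable_fsdp_ram_efficient_loading

    enable()
    try:
        yield
    finally:
        # Switch the setting back off even if loading raised an exception.
        disable()
```

With Accelerate installed, `with ram_efficient_loading(): model = load_model()` would wrap a (hypothetical) `load_model` call. Passing the callables explicitly keeps the sketch usable without the library, for example in tests.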


( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )


  • checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
  • output_path (str) — The path to save the merged checkpoint.
  • safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
  • remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.

Merges the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this when the model was saved with SHARDED_STATE_DICT. Weights are saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.
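As a rough illustration of the parameters and output naming described above, the sketch below wraps the merge call and returns the path where the merged file is expected. The function name `merge_fsdp_weights` and its location in `accelerate.utils` are assumptions (the heading is missing from this page); `expected_merged_file`, `merge_and_locate`, and the `merge_fn` parameter are hypothetical helpers introduced only for this example.

```python
from pathlib import Path


def expected_merged_file(output_path, safe_serialization=True):
    """Return the file the merged weights should land in, per the text above."""
    name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / name


def merge_and_locate(checkpoint_dir, output_path, safe_serialization=True,
                     remove_checkpoint_dir=False, merge_fn=None):
    """Merge sharded FSDP checkpoints, then return the expected output file."""
    if merge_fn is None:
        # Assumption: the utility is importable under this name.
        from accelerate.utils import merge_fsdp_weights as merge_fn

    merge_fn(
        checkpoint_dir,
        output_path,
        safe_serialization=safe_serialization,
        remove_checkpoint_dir=remove_checkpoint_dir,
    )
    return expected_merged_file(output_path, safe_serialization)
```

A call such as `merge_and_locate("checkpoints/model_shards", "merged")` (both paths are placeholders) would return `merged/model.safetensors`. Leaving `remove_checkpoint_dir` at its default of False keeps the shards until the merged file has been checked.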

class accelerate.FullyShardedDataParallelPlugin

( sharding_strategy: Union = None, backward_prefetch: Union = None, mixed_precision_policy: Union = None, auto_wrap_policy: Union = None, cpu_offload: Union = None, ignored_modules: Optional = None, state_dict_type: Union = None, state_dict_config: Union = None, optim_state_dict_config: Union = None, limit_all_gathers: bool = True, use_orig_params: bool = None, param_init_fn: Optional = None, sync_module_states: bool = None, forward_prefetch: bool = None, activation_checkpointing: bool = None, cpu_ram_efficient_loading: bool = None, transformer_cls_names_to_wrap: Optional = None, min_num_params: Optional = None )

This plugin is used to enable fully sharded data parallelism.
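A minimal sketch of constructing the plugin follows, using only boolean parameters from the signature above. The chosen values are illustrative rather than recommended settings, `FSDP_PLUGIN_KWARGS` and `build_accelerator` are hypothetical names, and passing the plugin through `Accelerator(fsdp_plugin=...)` assumes the usual Accelerate entry point. The imports are deferred so the configuration can be inspected without the library installed.

```python
# Illustrative settings only; every key is a parameter from the signature above.
FSDP_PLUGIN_KWARGS = {
    "limit_all_gathers": True,
    "use_orig_params": True,
    "sync_module_states": True,
    "cpu_ram_efficient_loading": True,
    "activation_checkpointing": False,
}


def build_accelerator(plugin_kwargs=None):
    """Create an Accelerator configured with an FSDP plugin."""
    # Deferred import: requires accelerate and torch to be installed.
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    plugin = FullyShardedDataParallelPlugin(**(plugin_kwargs or FSDP_PLUGIN_KWARGS))
    return Accelerator(fsdp_plugin=plugin)
```

In practice `build_accelerator()` would be called from a script started by a distributed launcher, since FSDP shards parameters across processes.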


( model )

Given model, creates an auto_wrap_policy based on the passed-in policy and on whether transformer_cls_to_wrap can be used.


( mixed_precision, buffer_autocast = False, override = False )

Sets the mixed precision policy for FSDP.


( state_dict_type = None )

Set the state dict config based on the StateDictType.
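The three methods above can be sketched together as one configuration step. The method names `set_auto_wrap_policy`, `set_mixed_precision`, and `set_state_dict_type` are assumptions, because their headings are missing from this page; the argument lists follow the signatures shown. `configure_plugin` is a hypothetical helper, and Accelerate may already call these methods itself during setup, so manual calls like this are likely only needed when overriding a setting.

```python
def configure_plugin(plugin, model, mixed_precision="bf16", state_dict_type=None):
    """Apply wrap policy, mixed precision, and state dict settings to `plugin`.

    `plugin` is expected to be a FullyShardedDataParallelPlugin (or any object
    exposing the same three methods); method names are assumed, see above.
    """
    # Derive the auto-wrap policy from the model and the plugin's own settings.
    plugin.set_auto_wrap_policy(model)

    # override=True asks the plugin to replace a policy that is already set.
    plugin.set_mixed_precision(mixed_precision, buffer_autocast=False, override=True)

    # None keeps the plugin's current or default state dict type.
    plugin.set_state_dict_type(state_dict_type)
    return plugin
```

Taking the plugin as an argument, rather than constructing it inside the helper, means the call order can be exercised with a stand-in object that records calls.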
