Accelerate documentation

Utilities for Fully Sharded Data Parallelism

You are viewing version v0.34.2. A newer version, v1.0.0rc1, is available.

accelerate.utils.enable_fsdp_ram_efficient_loading

( )

Enables RAM efficient loading of Hugging Face models for FSDP in the environment.

accelerate.utils.disable_fsdp_ram_efficient_loading

( )

Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
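
A minimal usage sketch, assuming Accelerate is installed: the two helpers above can bracket a model load so that RAM efficient loading is switched on only while the model is created. The wrapper name `load_with_ram_efficient_loading` and its `load_fn` callback are hypothetical, introduced here only for illustration.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def load_with_ram_efficient_loading(load_fn: Callable[[], T]) -> T:
    """Call load_fn with FSDP RAM efficient loading enabled (hypothetical helper)."""
    # Imported lazily so the sketch can be defined without Accelerate installed.
    from accelerate.utils import (
        disable_fsdp_ram_efficient_loading,
        enable_fsdp_ram_efficient_loading,
    )

    enable_fsdp_ram_efficient_loading()
    try:
        return load_fn()
    finally:
        # This switches the feature off; it does not restore any earlier setting.
        disable_fsdp_ram_efficient_loading()
```

Here `load_fn` would typically be a zero-argument closure that builds the model. Whether to disable the setting afterwards depends on the rest of the script, so the `finally` step is a design choice rather than a requirement.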

accelerate.utils.merge_fsdp_weights

( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )

Parameters

  • checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
  • output_path (str) — The path to save the merged checkpoint.
  • safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
  • remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.

Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this if SHARDED_STATE_DICT was used for the model. Weights are saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.
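
The call and its output location can be sketched as follows, assuming Accelerate is installed and `checkpoint_dir` points at an existing sharded checkpoint. The helper names `expected_merged_file` and `merge_checkpoint` are hypothetical; only `merge_fsdp_weights` and its parameters come from the reference above.

```python
from pathlib import Path


def expected_merged_file(output_path: str, safe_serialization: bool = True) -> Path:
    """Return where the merged weights should land, per the description above."""
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / filename


def merge_checkpoint(
    checkpoint_dir: str, output_path: str, safe_serialization: bool = True
) -> Path:
    """Merge a sharded FSDP checkpoint and return the expected output file."""
    # Imported lazily so the path helper works without Accelerate installed.
    from accelerate.utils import merge_fsdp_weights

    merge_fsdp_weights(
        checkpoint_dir,
        output_path,
        safe_serialization=safe_serialization,
        remove_checkpoint_dir=False,  # keep the shards until the merge is verified
    )
    return expected_merged_file(output_path, safe_serialization)
```

Leaving `remove_checkpoint_dir` at False is the cautious choice here: the sharded checkpoint stays on disk until the merged file has been checked.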

class accelerate.FullyShardedDataParallelPlugin

( sharding_strategy: Union = None, backward_prefetch: Union = None, mixed_precision_policy: Union = None, auto_wrap_policy: Union = None, cpu_offload: Union = None, ignored_modules: Optional = None, state_dict_type: Union = None, state_dict_config: Union = None, optim_state_dict_config: Union = None, limit_all_gathers: bool = True, use_orig_params: bool = None, param_init_fn: Optional = None, sync_module_states: bool = None, forward_prefetch: bool = None, activation_checkpointing: bool = None, cpu_ram_efficient_loading: bool = None, transformer_cls_names_to_wrap: Optional = None, min_num_params: Optional = None )

This plugin is used to enable fully sharded data parallelism.
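
A minimal sketch of how the plugin might be constructed and handed to an `Accelerator`, assuming Accelerate and PyTorch are installed and the script is started with a distributed launcher. The helper names `build_fsdp_plugin_kwargs` and `build_accelerator` are hypothetical, and the values are illustrative rather than recommended settings; the keyword names are taken from the signature above.

```python
def build_fsdp_plugin_kwargs() -> dict:
    """Keyword arguments for the plugin; values are illustrative only."""
    return {
        "use_orig_params": True,
        "sync_module_states": True,
        "cpu_ram_efficient_loading": True,
        "activation_checkpointing": False,
        "limit_all_gathers": True,
    }


def build_accelerator():
    """Create an Accelerator configured with the FSDP plugin."""
    # Imported lazily so the kwargs helper works without Accelerate installed.
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    plugin = FullyShardedDataParallelPlugin(**build_fsdp_plugin_kwargs())
    return Accelerator(fsdp_plugin=plugin)
```

Arguments that are omitted keep their default of None, in which case the plugin is left to choose a value itself.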

set_auto_wrap_policy

( model )

Given model, creates an auto_wrap_policy based on the passed-in policy and on whether transformer_cls_to_wrap can be used.

set_mixed_precision

( mixed_precision, buffer_autocast = False, override = False )

Sets the mixed precision policy for FSDP.

set_state_dict_type

( state_dict_type = None )

Set the state dict config based on the StateDictType.
