FullyShardedDataParallel
Fully Sharded Data Parallel (FSDP) is a parallelism method that combines the advantages of data and model parallelism for distributed training.
Unlike DistributedDataParallel (DDP), FSDP saves more memory because it doesn’t replicate a model on each GPU. It shards the model’s parameters, gradients, and optimizer states across GPUs. Each model shard processes a portion of the data and the results are synchronized to speed up training.
This guide covers how to set up training a model with FSDP and Accelerate, a library for managing distributed training.
pip install accelerate
Configuration options
Always start by running the accelerate config command to help Accelerate set up the correct distributed training environment.
accelerate config
The section below discusses some of the more important FSDP configuration options. Learn more about other available options in the fsdp_config parameter.
Sharding strategy
FSDP offers several sharding strategies to distribute a model. Refer to the table below to help you choose the best strategy for your setup. Specify a strategy with the fsdp_sharding_strategy parameter in the configuration file.
| sharding strategy | description | parameter value |
|---|---|---|
| FULL_SHARD | shards model parameters, gradients, and optimizer states | 1 |
| SHARD_GRAD_OP | shards gradients and optimizer states | 2 |
| NO_SHARD | don't shard the model | 3 |
| HYBRID_SHARD | shards model parameters, gradients, and optimizer states within each node, while each node keeps a full copy of the model | 4 |
| HYBRID_SHARD_ZERO2 | shards gradients and optimizer states within each node, while each node keeps a full copy of the model | 5 |
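As a quick illustration, the snippet below sketches how the strategy appears in the generated configuration file; it assumes FULL_SHARD and omits the rest of the file.

```yaml
# Illustrative excerpt from an Accelerate config file (other keys omitted)
fsdp_config:
  fsdp_sharding_strategy: 1  # 1 = FULL_SHARD, see the table above
```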
CPU offload
Offload model parameters and gradients to the CPU when they aren’t being used to save additional GPU memory. This is useful for scenarios where a model is too large even with FSDP.
Specify fsdp_offload_params: true in the configuration file to enable offloading.
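A minimal sketch of the corresponding entry in the configuration file (other keys omitted):

```yaml
# Illustrative excerpt: offload parameters and gradients to the CPU when idle
fsdp_config:
  fsdp_offload_params: true
```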
Wrapping policy
FSDP is applied by wrapping each layer in the network. The wrapping is usually applied in a nested way where the full weights are discarded after each forward pass to save memory for the next layer.
There are several wrapping policies available, but the auto wrapping policy is the simplest and doesn’t require any changes to your code. Specify fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP to wrap a Transformer layer and fsdp_transformer_layer_cls_to_wrap to determine which layer class to wrap (for example, BertLayer).
Size-based wrapping is also available; if a layer exceeds a certain number of parameters, it is wrapped. Specify fsdp_auto_wrap_policy: SIZE_BASED_WRAP and fsdp_min_num_params to set the minimum number of parameters for a layer to be wrapped.
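The sketch below shows how the two policies could look in the configuration file; the BertLayer class and the 100M parameter threshold are illustrative assumptions.

```yaml
# Transformer-based wrapping (assumes a BERT-style model with BertLayer blocks)
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: BertLayer

# Size-based wrapping instead (the threshold value is only an example)
# fsdp_config:
#   fsdp_auto_wrap_policy: SIZE_BASED_WRAP
#   fsdp_min_num_params: 100000000
```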
Checkpoints
Intermediate checkpoints should be saved as a sharded state dict because saving the full state dict, even with CPU offloading, is time-consuming and can cause NCCL Timeout errors due to indefinite hanging during broadcasting.
Specify fsdp_state_dict_type: SHARDED_STATE_DICT in the configuration file to save the sharded state dict. Now you can resume training from the sharded state dict with load_state.
accelerator.load_state("directory/containing/checkpoints")
Once training is complete though, you should save the full state dict because the sharded state dict is only compatible with FSDP.
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)
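Once the full state dict has been saved, the output directory can be loaded like any other Transformers checkpoint, for example for inference without FSDP. This is a minimal sketch; the model class and path are illustrative.

```py
from transformers import AutoModelForSequenceClassification

# Load the consolidated checkpoint saved with trainer.save_model()
model = AutoModelForSequenceClassification.from_pretrained("path/to/output_dir")
```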
TPU
PyTorch XLA, a package for running PyTorch on XLA devices, enables FSDP on TPUs. Modify the configuration file to include the parameters below. Refer to the xla_fsdp_settings parameter for additional XLA-specific parameters you can configure for FSDP.
xla: True # must be set to True to enable PyTorch/XLA
xla_fsdp_settings: # XLA specific FSDP parameters
xla_fsdp_grad_ckpt: True # enable gradient checkpointing
Training
After running accelerate config, your configuration file should be ready. The example configuration file below fully shards the parameters, gradients, and optimizer states across two GPUs. Your file may look different depending on how you set up your configuration.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Run the accelerate launch command to launch a training script with the FSDP configurations you chose in the configuration file.
accelerate launch my-training-script.py
It is also possible to directly specify some of the FSDP arguments in the command line.
accelerate launch --fsdp="full shard" --fsdp_config="path/to/fsdp_config/" my-training-script.py
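The training script itself doesn’t need any FSDP-specific code when you use Trainer, since the FSDP settings are picked up from the Accelerate configuration at launch time. Below is a minimal sketch of what my-training-script.py could look like; the model, dataset, and hyperparameters are illustrative assumptions chosen to match the BertLayer wrapping in the example configuration.

```py
# my-training-script.py: a minimal, illustrative Trainer-based script.
# All model, dataset, and hyperparameter choices are assumptions for this sketch.
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# A small slice of a public dataset, used purely for illustration
dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128),
    batched=True,
)

training_args = TrainingArguments(
    output_dir="fsdp-output",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    bf16=True,  # matches mixed_precision: bf16 in the example configuration
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("fsdp-output")
```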
Resources
FSDP is a powerful tool for training large models with fewer GPUs compared to other parallelism strategies. Refer to the resources below to learn even more about FSDP.
- Follow along with the more in-depth Accelerate guide for FSDP.
- Read the Introducing PyTorch Fully Sharded Data Parallel (FSDP) API blog post.
- Read the Scaling PyTorch models on Cloud TPUs with FSDP blog post.