Use Fully Sharded Data Parallel (FSDP) to train large models with billions of parameters efficiently on multiple GPUs and across multiple machines.
.. note:: This is an experimental feature.
Today, large models with billions of parameters are trained with many GPUs across several machines in parallel.
Even a single H100 GPU with 80 GB of VRAM (one of the largest available today) is not enough to train a 30B-parameter model, even with batch size 1 and 16-bit precision.
The memory consumption for training is generally made up of
1. the model parameters,
2. the layer activations (forward),
3. the gradients (backward), and
4. the optimizer states (e.g., Adam keeps two additional exponential averages per parameter).
|
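As a rough back-of-the-envelope example for the 30B-parameter model mentioned above (assuming 16-bit weights and gradients and 32-bit Adam states, a common setup): the weights alone occupy 30B × 2 bytes ≈ 60 GB, the gradients another ≈ 60 GB, and Adam's two exponential averages a further ≈ 240 GB, before a single activation is stored.
This already far exceeds the 80 GB of a single H100.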
When the sum of these memory components exceeds the VRAM of a single GPU, regular data-parallel (DDP) training can no longer be employed.
One of the methods that can alleviate this limitation is **model-parallel** training, known in PyTorch as **FSDP** (Fully Sharded Data Parallel).
In this guide, you will learn how to effectively scale large models with it.
----
***************************
Checklist: When to use FSDP
***************************
|
✅ I have multiple GPUs
✅ I have tried regular DDP training with batch size 1 but I run out of memory
Models with many large layers, such as the linear layers in LLMs, ViTs, etc. with more than 100M parameters, benefit the most from FSDP, because the memory they consume through parameters, activations, and corresponding optimizer states can be split evenly across all GPUs.
However, one should avoid splitting small layers that have a few thousand parameters because communication overhead would dominate and slow the training down.
We can specify a list of layer classes in the **wrapping policy** to inform FSDP which parameters it should wrap:
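The sketch below assumes the model is built from ``torch.nn.TransformerEncoderLayer`` blocks; substitute the large layer classes of your own model:

.. code-block:: python

    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # Every nn.TransformerEncoderLayer becomes its own FSDP unit and is sharded
    # across GPUs; small layers outside this set are left alone to avoid
    # unnecessary communication overhead
    policy = {nn.TransformerEncoderLayer}

    strategy = FSDPStrategy(auto_wrap_policy=policy)
    trainer = L.Trainer(accelerator="cuda", devices=4, strategy=strategy)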
Verify that FSDP works with your model by comparing the peak memory usage printed in the CUDA memory summary (see example above) with regular DDP training.
You should see a decrease in allocated memory and a slight increase in iteration time:
.. list-table:: Numbers were produced with A100 40GB GPUs, Lightning 2.1 and PyTorch 2.1.
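One way to get these numbers yourself is a small, illustrative callback that prints the CUDA memory summary once training ends, so the peak allocated memory can be compared between the DDP and FSDP runs:

.. code-block:: python

    import torch
    import lightning as L


    class MemoryReport(L.Callback):
        # Illustrative helper: prints the CUDA memory summary at the end of training
        def on_train_end(self, trainer, pl_module):
            print(torch.cuda.memory_summary())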
The standard practice in PyTorch is to put all model parameters into CPU memory first and then in a second step move them to the GPU device.
However, the larger the model the longer these two steps take.
If you create the large model layers inside the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook, you can initialize very large models quickly and reduce memory peaks.
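A minimal sketch of this pattern, using a simple stack of large linear layers as a stand-in for your model:

.. code-block:: python

    import torch.nn as nn
    import lightning as L


    class LargeModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = None  # large layers are NOT created here

        def configure_model(self):
            # The hook may be called more than once, so keep it idempotent
            if self.model is not None:
                return
            # Layers created here can be sharded by FSDP right away,
            # avoiding a full copy of the model in CPU memory first
            self.model = nn.Sequential(*[nn.Linear(8192, 8192) for _ in range(16)])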
By default, FSDP automatically shards 1) the model weights, 2) the gradients during backward, and 3) the optimizer states across all GPUs for the layers selected by the auto-wrap policy.
You can configure the following options to trade off memory for speed (a configuration sketch follows the list):
1. Try the default settings first (``FULL_SHARD``). This is the slowest option but saves you the most memory.
2. Try ``SHARD_GRAD_OP``. If you run out of memory, revert to the default (``FULL_SHARD``). Otherwise, you should expect to see an increase in iteration speed.
|
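For example, the strategy can be switched through the ``sharding_strategy`` argument; the string names below follow PyTorch's ``ShardingStrategy`` options:

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # Default: weights, gradients, and optimizer states are all sharded (lowest memory)
    # strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")

    # Alternative: only gradients and optimizer states are sharded; the full weights
    # stay materialized on every GPU between forward and backward (faster, more memory)
    strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")

    trainer = L.Trainer(strategy=strategy)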
Here is the memory and speed impact for each option when configured in our example code:
.. list-table:: Numbers were produced with A100 40GB GPUs, Lightning 2.1 and PyTorch 2.1.
If you are short on GPU memory because you are training large models with 10+ billion parameters or require extreme batch sizes, consider trading off speed for more memory by enabling activation checkpointing or CPU offload.
Activations, the intermediate outputs of layers, are stored during the forward pass and needed during the backward pass to compute the gradients.
By enabling activation checkpointing, we can choose to discard and recompute selected layer activations dynamically during the backward pass when they are required, instead of storing them throughout the forward pass.
While this approach may slightly reduce training speed, it significantly reduces memory consumption.
The freed-up memory can then be allocated to increase the model's capacity or accommodate larger batch sizes, resulting in potential performance improvements.
As in our example, it is typical to set the ``activation_checkpointing_policy`` the same as ``auto_wrap_policy``.
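A sketch of what this can look like, again assuming transformer blocks are the large layers in the model:

.. code-block:: python

    import torch.nn as nn
    from lightning.pytorch.strategies import FSDPStrategy

    # The activations of each transformer block are discarded after the forward pass
    # and recomputed during backward, trading compute for memory
    policy = {nn.TransformerEncoderLayer}
    strategy = FSDPStrategy(
        auto_wrap_policy=policy,
        activation_checkpointing_policy=policy,
    )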
Offload parameters to CPU
=========================
The most drastic GPU memory savings can be achieved by offloading parameters to the CPU:
.. code-block:: python

    # Set `cpu_offload=True`
    strategy = FSDPStrategy(..., cpu_offload=True)
    trainer = L.Trainer(..., strategy=strategy)
The drawback is a much slower training speed due to the added communication between CPU and GPU for transferring parameters in every forward pass.
You should use this only if you have enough CPU memory and other scaling methods don’t give you enough memory savings.
In our example, we see a 3.5x memory saving, but a significant increase in iteration time:
.. list-table:: Numbers were produced with A100 40GB GPUs, Lightning 2.1 and PyTorch 2.1.
   :widths: 25 25 25 25
   :header-rows: 1

   * -
     - DDP
     - FSDP
     - FSDP + CPU offload
   * - Memory (MB)
     - 23’125
     - 9’627
     - 2’790
   * - Iterations per second
     - 4.31
     - 3.19
     - 0.02
----
**********************************
Advanced performance optimizations
**********************************
Once you have a good understanding of how the different FSDP settings impact the memory usage and speed of your model, here are a few more settings to squeeze out the last bit of performance.
Their benefit depends heavily on the specific use case, so you will have to toggle them on and off to see the impact on your model.
`See the full list of optimizers that support this <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
Limit all-gathers
=================
If you are running training close to the maximum GPU memory limit, you might be getting so-called CUDA malloc retries.
This essentially means the GPU is running out of memory, but before crashing completely, the allocator tries to find unused or cached memory it can free.
When they happen frequently, these retries can have a significant impact on speed.
Normally, you would decrease the batch size slightly to avoid it.
With FSDP, you have one more knob you can tweak to combat the issue, by setting ``limit_all_gathers=True``:
.. code-block:: python

    strategy = FSDPStrategy(
        # Default: the CPU will schedule the transfer of weights between GPUs
        # at will, sometimes too aggressively
        # limit_all_gathers=False,
        # Enable this if you are close to the max. GPU memory usage
        limit_all_gathers=True,
    )
    trainer = L.Trainer(..., strategy=strategy)
You can monitor CUDA malloc retries in the output of ``torch.cuda.memory_summary()``, for example, or through the PyTorch profiler.
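If you prefer a programmatic check, the retry counter is also exposed through ``torch.cuda.memory_stats()``; a quick sketch:

.. code-block:: python

    import torch

    stats = torch.cuda.memory_stats()
    # Number of times the caching allocator had to free cached blocks and retry
    # an allocation; ideally this stays at or close to zero
    print(stats["num_alloc_retries"])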
Manual wrapping
===============
Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model.
To activate parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function.
Internally in Lightning, we enable a context manager around the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook to make sure the ``wrap`` parameters are passed correctly.
Here is an example that uses ``wrap`` to create a model:
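The following is a minimal sketch; the ``MyModel`` name and the toy layer sizes are placeholders, and only the wrapped block gets sharded while the small linear layer is left untouched:

.. code-block:: python

    import torch
    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy
    from torch.distributed.fsdp.wrap import wrap


    class MyModel(L.LightningModule):
        def configure_model(self):
            self.linear_layer = nn.Linear(32, 32)
            block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

            # Modules are sharded across processes as soon as they are passed to `wrap`
            block = wrap(block)
            self.model = nn.Sequential(self.linear_layer, nn.ReLU(), block)

        def training_step(self, batch, batch_idx):
            # Assumes `batch` is a float tensor of shape (N, 32)
            return self.model(batch).sum()

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())


    model = MyModel()
    trainer = L.Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy())
    # trainer.fit(model, train_dataloaders=...) as usual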