Add FSDP docs (#7791)

* Add FSDP docs

* Address reviews

* Add note about how FSDP can replace pipe parallelism

* Add import

* Remove sentence
Sean Naren 2021-06-02 10:52:48 +01:00 committed by GitHub
parent e4ba06c70f
commit 0a72fd2284
1 changed file with 99 additions and 1 deletion


@@ -23,7 +23,7 @@ This means we cannot sacrifice throughput as much as if we were fine-tuning, bec
Overall:
* When **fine-tuning** a model, use advanced memory efficient plugins such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
* When **pre-training** a model, use simpler optimizations such :ref:`sharded`, :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes
* When **pre-training** a model, use simpler optimizations such as :ref:`sharded`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded`, scaling the number of GPUs to reach larger parameter sizes
* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant
For example, when using 128 GPUs you can **pre-train** large 10 to 20 billion parameter models using :ref:`deepspeed-zero-stage-2` without taking the performance hit of more advanced, optimized multi-GPU plugins.
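As a rough sketch of how these recommendations map onto the ``Trainer`` (assuming the ``deepspeed_stage_3_offload`` and ``deepspeed_stage_2`` plugin aliases are available in your installed Lightning version and DeepSpeed is installed):

.. code-block:: python

    from pytorch_lightning import Trainer

    # Fine-tuning on limited compute: shard as much state as possible
    # (DeepSpeed ZeRO Stage 3 with CPU offload)
    trainer = Trainer(gpus=8, plugins='deepspeed_stage_3_offload', precision=16)

    # Pre-training at scale: lighter sharding keeps throughput high
    # (DeepSpeed ZeRO Stage 2)
    trainer = Trainer(gpus=128, plugins='deepspeed_stage_2', precision=16)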
@@ -73,6 +73,104 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
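For example, a minimal sketch of enabling Sharded Training from code rather than the command line (assuming FairScale is installed):

.. code-block:: python

    from pytorch_lightning import Trainer

    # Optimizer state and gradients are sharded across processes automatically;
    # the LightningModule itself requires no changes
    trainer = Trainer(gpus=4, plugins='ddp_sharded', precision=16)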
----------
.. _fully-sharded:
Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^
.. warning::
Fully Sharded Training is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.
`Fully Sharded <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ shards optimizer state, gradients and parameters across data parallel workers. This allows you to fit much larger models into memory across multiple GPUs.
Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort.
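At a high level, the plugin is enabled with a single ``Trainer`` flag, as in the sketch below; a complete example showing how to wrap your model follows.

.. code-block:: python

    from pytorch_lightning import Trainer

    # Enable Fully Sharded Training; 16-bit precision further reduces memory usage
    trainer = Trainer(gpus=4, plugins='fsdp', precision=16)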
Shard Parameters to Reach 10+ Billion Parameters
""""""""""""""""""""""""""""""""""""""""""""""""
To reach larger parameter sizes and be memory efficient, we have to shard parameters. There are various ways to enable this.
.. note::
Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.
Wrap the Model
""""""""""""""
To activate parameter sharding, you must wrap your model using provided ``wrap`` or ``auto_wrap`` functions as described below. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly.
When not using Fully Sharded, these wrap functions are no-ops. This means once the changes have been made, there is no need to remove the changes for other plugins.
This is a requirement for really large models and also saves on instantiation time as modules are sharded instantly, rather than after the entire model is created in memory.
``auto_wrap`` will recursively wrap ``torch.nn.Module`` instances within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information `here <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html>`__).
``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.
``wrap`` will simply wrap the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.
Below is an example of using both ``wrap`` and ``auto_wrap`` to create your model.
.. code-block:: python

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap


    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            # Created within sharded model context, modules are instantly sharded across processes
            # as soon as they are wrapped with ``wrap`` or ``auto_wrap``

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(nn.Linear(32, 32))

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(
                nn.Sequential(
                    nn.Linear(32, 32),
                    nn.ReLU()
                )
            )

            # For best memory efficiency,
            # add fairscale activation checkpointing
            final_block = auto_wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(32, 32),
                        nn.ReLU()
                    )
                )
            )
            self.model = nn.Sequential(
                linear_layer,
                nn.ReLU(),
                block,
                final_block
            )

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())


    model = MyModel()
    trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
    trainer.fit(model)

    trainer.test()
    trainer.predict()
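As noted above, ``wrap`` and ``auto_wrap`` are no-ops outside of a Fully Sharded context, so the same ``LightningModule`` can be reused with other plugins. A minimal sketch, assuming a standard DDP setup:

.. code-block:: python

    # ``wrap`` and ``auto_wrap`` inside ``configure_sharded_model`` simply return the
    # modules unchanged when the ``fsdp`` plugin is not active
    model = MyModel()
    trainer = Trainer(gpus=4, accelerator='ddp', precision=16)
    trainer.fit(model)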
----------
.. _fairscale-activation-checkpointing:
FairScale Activation Checkpointing