Add FSDP docs (#7791)

* Add FSDP docs
* Address reviews
* Add note about how FSDP can replace pipe parallelism
* Add import
* Remove sentence

parent e4ba06c70f
commit 0a72fd2284

@@ -23,7 +23,7 @@ This means we cannot sacrifice throughput as much as if we were fine-tuning, bec

Overall:

* When **fine-tuning** a model, use advanced memory efficient plugins such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
- * When **pre-training** a model, use simpler optimizations such as :ref:`sharded`, :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes
+ * When **pre-training** a model, use simpler optimizations such as :ref:`sharded`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded`, scaling the number of GPUs to reach larger parameter sizes
* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant

For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit with more advanced optimized multi-gpu plugins.

@@ -73,6 +73,104 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
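
For instance, the sketch below enables Sharded Training on a single multi-GPU machine. The ``LitModel`` module, the toy dataset, and the ``gpus``/``precision`` settings are illustrative assumptions; the ``ddp_sharded`` plugin flag is the one referenced above.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters())


    # Toy data; any existing DataLoader or DataModule works the same way
    dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))
    train_loader = DataLoader(dataset, batch_size=8)

    # Sharded Training is enabled purely through the plugin flag;
    # the LightningModule itself is unchanged
    trainer = pl.Trainer(gpus=4, plugins='ddp_sharded', precision=16)
    trainer.fit(LitModel(), train_loader)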

----------

.. _fully-sharded:

Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^

.. warning::
    Fully Sharded Training is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.

`Fully Sharded <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ shards optimizer state, gradients and parameters across data parallel workers. This allows you to fit much larger models into memory across multiple GPUs.

Fully Sharded Training removes the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes distributed communication with minimal effort.
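
Enabling it is a Trainer-level switch, as in the complete example at the end of this section. A minimal sketch is shown below; ``model`` is assumed to be an existing ``LightningModule`` prepared as described in the rest of this section, and the GPU count and precision are illustrative.

.. code-block:: python

    from pytorch_lightning import Trainer

    # ``fsdp`` is the plugin flag used in the full example further below;
    # ``model`` is a LightningModule that wraps its submodules in ``configure_sharded_model``
    trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
    trainer.fit(model)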

Shard Parameters to Reach 10+ Billion Parameters
""""""""""""""""""""""""""""""""""""""""""""""""

To reach larger parameter sizes and be memory efficient, we have to shard parameters. There are various ways to enable this.

.. note::
    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
    This is a limitation of Fully Sharded Training that will be resolved in the future.

Wrap the Model
""""""""""""""

To activate parameter sharding, you must wrap your model using the provided ``wrap`` or ``auto_wrap`` functions as described below. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly.

When not using Fully Sharded, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove them for other plugins, as illustrated in the sketch below.
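
As a quick illustration (a sketch reusing the ``MyModel`` defined in the example further below, with an illustrative GPU count and precision), the same module can be trained with a regular DDP accelerator without touching the wrapping code:

.. code-block:: python

    from pytorch_lightning import Trainer

    # Outside of the ``fsdp`` plugin, the ``wrap``/``auto_wrap`` calls inside
    # ``MyModel`` return the modules unchanged, so plain DDP training works as-is
    model = MyModel()
    trainer = Trainer(gpus=4, accelerator='ddp', precision=16)
    trainer.fit(model)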

Wrapping the model this way is a requirement for really large models and also saves on instantiation time, as modules are sharded as soon as they are wrapped rather than after the entire model is created in memory.

``auto_wrap`` will recursively wrap the ``torch.nn.Module`` layers within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (more information `here <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html>`__).

``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.
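
As an illustration of what shared parameters mean here, consider the hedged sketch below of a weight-tied module (the layer sizes and names are made up). Because the embedding and the decoder share a single tensor, both must end up in the same sharded unit, so a module like this should be wrapped manually with ``wrap`` rather than left to ``auto_wrap``.

.. code-block:: python

    import torch.nn as nn


    class TiedLM(nn.Module):
        """Toy language-model head with weight tying: the embedding and the
        output projection share one parameter tensor."""

        def __init__(self, vocab_size=1000, hidden=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.decoder = nn.Linear(hidden, vocab_size, bias=False)
            # Shared parameter: ``auto_wrap`` could place ``embed`` and ``decoder``
            # into different sharded units, splitting the tied weight
            self.decoder.weight = self.embed.weight

        def forward(self, tokens):
            return self.decoder(self.embed(tokens))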

``wrap`` will simply wrap the module in a Fully Sharded Parallel class, with the correct parameters picked up from the Lightning context manager.

Below is an example of using both ``wrap`` and ``auto_wrap`` to create your model.

.. code-block:: python

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap


    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            # Created within sharded model context, modules are instantly sharded across processes
            # as soon as they are wrapped with ``wrap`` or ``auto_wrap``

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(nn.Linear(32, 32))

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(
                nn.Sequential(
                    nn.Linear(32, 32),
                    nn.ReLU()
                )
            )

            # For best memory efficiency,
            # add fairscale activation checkpointing
            final_block = auto_wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(32, 32),
                        nn.ReLU()
                    )
                )
            )
            self.model = nn.Sequential(
                linear_layer,
                nn.ReLU(),
                block,
                final_block
            )

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())


    model = MyModel()
    trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
    trainer.fit(model)

    trainer.test()
    trainer.predict()
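
The example above only defines ``configure_sharded_model`` and ``configure_optimizers``. As a minimal sketch of how the wrapped ``self.model`` might be used from the rest of the ``LightningModule``, the methods below could be added to ``MyModel``; the batch structure, loss function and logging key are illustrative assumptions, not part of the original example.

.. code-block:: python

    class MyModel(pl.LightningModule):
        # ``configure_sharded_model`` and ``configure_optimizers`` as in the example above
        ...

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self(x), y)
            self.log("train_loss", loss)
            return loss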

----------

.. _fairscale-activation-checkpointing:

FairScale Activation Checkpointing