diff --git a/docs/source/advanced/advanced_gpu.rst b/docs/source/advanced/advanced_gpu.rst
index 8146744b52..0e43d4bff4 100644
--- a/docs/source/advanced/advanced_gpu.rst
+++ b/docs/source/advanced/advanced_gpu.rst
@@ -23,7 +23,7 @@ This means we cannot sacrifice throughput as much as if we were fine-tuning, bec
 Overall:
 
 * When **fine-tuning** a model, use advanced memory efficient plugins such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
-* When **pre-training** a model, use simpler optimizations such :ref:`sharded`, :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes
+* When **pre-training** a model, use simpler optimizations such as :ref:`sharded`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded`, scaling the number of GPUs to reach larger parameter sizes
 * For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant
 
 For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit with more advanced optimized multi-gpu plugins.
@@ -73,6 +73,104 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
 
 Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
 
+----------
+
+.. _fully-sharded:
+
+Fully Sharded Training
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. warning::
+    Fully Sharded Training is in beta and the API is subject to change. Please create an `issue `_ if you run into any issues.
+
+`Fully Sharded `__ shards optimizer state, gradients and parameters across data parallel workers. This allows you to fit much larger models into memory across multiple GPUs.
+
+Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort.
+
+Shard Parameters to Reach 10+ Billion Parameters
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+To reach larger parameter sizes and remain memory efficient, we have to shard parameters. There are various ways to enable this.
+
+.. note::
+    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
+    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
+    This is a limitation of Fully Sharded Training that will be resolved in the future.
+
+Wrap the Model
+""""""""""""""
+
+To activate parameter sharding, you must wrap your model using the provided ``wrap`` or ``auto_wrap`` functions as described below. Internally, Lightning enables a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly.
+
+When not using Fully Sharded Training, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove them for other plugins.
+
+This is a requirement for really large models, and it also saves on instantiation time, as modules are sharded as soon as they are created rather than after the entire model is built in memory.
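+
+Because ``wrap`` and ``auto_wrap`` are no-ops outside the Fully Sharded context, the same ``LightningModule`` can also be trained with a non-sharded plugin without touching the model code. Below is a minimal sketch of this; it reuses the ``MyModel`` defined in the full example further down and assumes a Lightning version where ``accelerator='ddp'`` selects standard DDP.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+
+    model = MyModel()  # same module as in the Fully Sharded example below
+
+    # No ``fsdp`` plugin here: ``wrap``/``auto_wrap`` inside ``configure_sharded_model``
+    # return the modules unchanged, and training runs as plain DDP
+    trainer = Trainer(gpus=4, accelerator='ddp', precision=16)
+    trainer.fit(model)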
+
+``auto_wrap`` will recursively wrap ``torch.nn.Module`` instances within the ``LightningModule`` with nested Fully Sharded Wrappers,
+signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information `here `__).
+
+``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.
+
+``wrap`` will simply wrap the module with a Fully Sharded Parallel class, with the correct parameters passed in from the Lightning context manager.
+
+Below is an example of using both ``wrap`` and ``auto_wrap`` to create your model.
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    import pytorch_lightning as pl
+    from pytorch_lightning import Trainer
+    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
+
+
+    class MyModel(pl.LightningModule):
+        ...
+
+        def configure_sharded_model(self):
+            # Modules created within the sharded model context are instantly sharded across
+            # processes as soon as they are wrapped with ``wrap`` or ``auto_wrap``
+
+            # Wraps the layer in a Fully Sharded Wrapper automatically
+            linear_layer = wrap(nn.Linear(32, 32))
+
+            # Wraps the module recursively
+            # based on a minimum number of parameters (default 100M parameters)
+            block = auto_wrap(
+                nn.Sequential(
+                    nn.Linear(32, 32),
+                    nn.ReLU()
+                )
+            )
+
+            # For best memory efficiency,
+            # add FairScale activation checkpointing
+            final_block = auto_wrap(
+                checkpoint_wrapper(
+                    nn.Sequential(
+                        nn.Linear(32, 32),
+                        nn.ReLU()
+                    )
+                )
+            )
+            self.model = nn.Sequential(
+                linear_layer,
+                nn.ReLU(),
+                block,
+                final_block
+            )
+
+        def configure_optimizers(self):
+            return torch.optim.AdamW(self.model.parameters())
+
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
+    trainer.fit(model)
+
+    trainer.test()
+    trainer.predict()
+
+
+----------
+
 .. _fairscale-activation-checkpointing:
 
 FairScale Activation Checkpointing