Add Strategy page to docs (#11441)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Kaushik B 2022-03-06 06:07:48 +05:30 committed by GitHub
parent ce956af4f2
commit a14783ea8c
4 changed files with 119 additions and 64 deletions

View File

@@ -45,40 +45,6 @@ to configure Lightning to integrate with a :ref:`custom-cluster`.
.. image:: ../_static/images/accelerator/overview.svg
**********************
Create a custom plugin
**********************
Expert users may choose to extend an existing plugin by overriding its methods ...
.. code-block:: python

    from pytorch_lightning.strategies import DDPStrategy


    class CustomDDPStrategy(DDPStrategy):
        def configure_ddp(self):
            self._model = MyCustomDistributedDataParallel(
                self.model,
                device_ids=...,
            )
or by subclassing the base classes :class:`~pytorch_lightning.strategies.Strategy` or
:class:`~pytorch_lightning.plugins.precision.PrecisionPlugin` to create new ones. These custom plugins
can then be passed into the Trainer directly or via a (custom) accelerator:
.. code-block:: python

    # custom plugins
    trainer = Trainer(strategy=CustomDDPStrategy(), plugins=[CustomPrecisionPlugin()])

    # fully custom accelerator and plugins
    accelerator = MyAccelerator()
    precision_plugin = MyPrecisionPlugin()
    training_type_plugin = CustomDDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
    trainer = Trainer(strategy=training_type_plugin)
The full list of built-in plugins is listed below.
@@ -86,35 +52,6 @@ The full list of built-in plugins is listed below.
For help setting up custom plugins/accelerators, please reach out to us at **support@pytorchlightning.ai**
----------
Training Strategies
-------------------
.. currentmodule:: pytorch_lightning.strategies

.. autosummary::
    :nosignatures:
    :template: classtemplate.rst

    Strategy
    SingleDeviceStrategy
    ParallelStrategy
    DataParallelStrategy
    DDPStrategy
    DDP2Strategy
    DDPShardedStrategy
    DDPSpawnShardedStrategy
    DDPSpawnStrategy
    BaguaStrategy
    DeepSpeedStrategy
    HorovodStrategy
    SingleTPUStrategy
    TPUSpawnStrategy
Precision Plugins
-----------------

View File

@@ -0,0 +1,116 @@
.. _strategy:
########
Strategy
########
Strategy controls the model distribution across training, evaluation, and prediction used by the :doc:`Trainer <../common/trainer>`. It can be selected by passing a strategy alias
(``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, and so on) or a custom strategy instance to the ``strategy`` parameter of the Trainer.
The Strategy in PyTorch Lightning handles the following responsibilities:
* Launches and tears down training processes (if applicable).
* Sets up communication between processes (NCCL, GLOO, MPI, and so on).
* Provides a unified communication interface for reduction, broadcast, and so on.
* Owns the :class:`~pytorch_lightning.core.lightning.LightningModule`.
* Handles/owns optimizers and schedulers.
:class:`~pytorch_lightning.strategies.strategy.Strategy` also manages the accelerator, precision, and checkpointing plugins.
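For illustration, here is a minimal sketch of that composition. The ``NativeMixedPrecisionPlugin`` and ``TorchCheckpointIO`` constructor arguments shown are assumptions about the current plugin APIs, not the only way to wire this together:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
    from pytorch_lightning.plugins.io import TorchCheckpointIO
    from pytorch_lightning.strategies import DDPStrategy

    # the strategy owns the precision and checkpoint I/O plugins it is given;
    # the accelerator is resolved from the Trainer flags below
    strategy = DDPStrategy(
        precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),  # assumed signature: (precision, device)
        checkpoint_io=TorchCheckpointIO(),
    )
    trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)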
****************************************
Training Strategies with Various Configs
****************************************
.. code-block:: python

    # Training with the DistributedDataParallel strategy on 4 GPUs
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

    # Training with the custom DistributedDataParallel strategy on 4 GPUs
    trainer = Trainer(strategy=DDPStrategy(...), accelerator="gpu", devices=4)

    # Training with the DDP Spawn strategy using auto accelerator selection
    trainer = Trainer(strategy="ddp_spawn", accelerator="auto", devices=4)

    # Training with the DeepSpeed strategy on available GPUs
    trainer = Trainer(strategy="deepspeed", accelerator="gpu", devices="auto")

    # Training with the DDP strategy using 3 CPU processes
    trainer = Trainer(strategy="ddp", accelerator="cpu", devices=3)

    # Training with the DDP Spawn strategy on 8 TPU cores
    trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices=8)

    # Training with the default IPU strategy on 8 IPUs
    trainer = Trainer(accelerator="ipu", devices=8)
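Strategy objects also carry strategy-specific options. As a small follow-up sketch to the ``DDPStrategy(...)`` line above (assuming, as documented for DDP, that extra keyword arguments are forwarded to ``torch.nn.parallel.DistributedDataParallel``):

.. code-block:: python

    from pytorch_lightning.strategies import DDPStrategy

    # extra keyword arguments on DDPStrategy are forwarded to DistributedDataParallel
    trainer = Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator="gpu",
        devices=4,
    )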
----------
************************
Create a Custom Strategy
************************
Expert users may choose to extend an existing strategy by overriding its methods.
.. code-block:: python

    from pytorch_lightning.strategies import DDPStrategy


    class CustomDDPStrategy(DDPStrategy):
        def configure_ddp(self):
            self.model = MyCustomDistributedDataParallel(
                self.model,
                device_ids=...,
            )
Alternatively, entirely new strategies can be created by subclassing the base class :class:`~pytorch_lightning.strategies.Strategy`. These custom strategies
can then be passed into the ``Trainer`` directly via the ``strategy`` parameter.
.. code-block:: python

    # custom strategy
    trainer = Trainer(strategy=CustomDDPStrategy())

    # fully custom accelerator and plugins
    accelerator = MyAccelerator()
    precision_plugin = MyPrecisionPlugin()
    training_strategy = CustomDDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
    trainer = Trainer(strategy=training_strategy)
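As one more hypothetical sketch of extending an existing strategy, a subclass can wrap the strategy's lifecycle hooks. The hook names used below (``setup_environment``, ``teardown``) are assumed to match the base ``Strategy`` API:

.. code-block:: python

    import logging

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy

    log = logging.getLogger(__name__)


    class VerboseDDPStrategy(DDPStrategy):
        """Hypothetical strategy that logs its lifecycle transitions."""

        def setup_environment(self) -> None:
            log.info("Setting up the distributed environment")  # before processes connect
            super().setup_environment()

        def teardown(self) -> None:
            log.info("Tearing down the strategy")  # after fit/validate/test/predict finishes
            super().teardown()


    trainer = Trainer(strategy=VerboseDDPStrategy(), accelerator="gpu", devices=2)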
The complete list of built-in strategies is shown below.
----------
****************************
Built-In Training Strategies
****************************
.. currentmodule:: pytorch_lightning.strategies

.. autosummary::
    :nosignatures:
    :template: classtemplate.rst

    BaguaStrategy
    DDP2Strategy
    DDPFullyShardedStrategy
    DDPShardedStrategy
    DDPSpawnShardedStrategy
    DDPSpawnStrategy
    DDPStrategy
    DataParallelStrategy
    DeepSpeedStrategy
    HorovodStrategy
    IPUStrategy
    ParallelStrategy
    SingleDeviceStrategy
    SingleTPUStrategy
    Strategy
    TPUSpawnStrategy

View File

@@ -52,6 +52,7 @@ PyTorch Lightning
   extensions/datamodules
   extensions/logging
   extensions/plugins
   extensions/strategy
   extensions/loops
.. toctree::

View File

@@ -47,7 +47,8 @@ else:
class TPUSpawnStrategy(DDPSpawnStrategy):
    """Strategy for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""
    """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
    method."""

    strategy_name = "tpu_spawn"