Accelerator API docs (#6936)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-04-10 08:55:07 +02:00 committed by GitHub
parent b85cfbe8f3
commit 20ff50caa6
7 changed files with 125 additions and 60 deletions

View File

@@ -1,6 +1,21 @@
API References
==============
Accelerator API
---------------
.. currentmodule:: pytorch_lightning.accelerators
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
Accelerator
CPUAccelerator
GPUAccelerator
TPUAccelerator
Core API
--------

View File

@@ -1,10 +1,56 @@
.. _accelerators:
############
Accelerators
############
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators
also manage distributed accelerators (like DP, DDP, HPC cluster).
Accelerators can also be configured to run on arbitrary clusters using Plugins or to link up to arbitrary
also manage distributed communication through :ref:`Plugins` (like DP, DDP, HPC cluster) and
can also be configured to run on arbitrary clusters or to link up to arbitrary
computational strategies like 16-bit precision via AMP and Apex.
**For help setting up custom plugin/accelerator please reach out to us at support@pytorchlightning.ai**
An Accelerator is meant to deal with one type of hardware.
Currently there are accelerators for:
- CPU
- GPU
- TPU
Each Accelerator gets two plugins upon initialization:
one to handle the training routine and one to handle precision.
.. testcode::

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

    accelerator = GPUAccelerator(
        precision_plugin=NativeMixedPrecisionPlugin(),
        training_type_plugin=DDPPlugin(),
    )
    trainer = Trainer(accelerator=accelerator)
We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to new hardware,
new distributed training strategies, or new clusters (see the sketch below).
.. warning:: The Accelerator API is in beta and subject to change.
For help setting up custom plugins/accelerators, please reach out to us at **support@pytorchlightning.ai**
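As a rough illustration of such an extension, a custom accelerator can subclass :class:`Accelerator` and be
handed to the Trainer just like the built-in ones. A minimal sketch (``MyAccelerator`` is a hypothetical
name, not part of Lightning):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin


    class MyAccelerator(Accelerator):
        # Override hooks such as ``setup`` for your hardware here; the base class
        # already wires the training type and precision plugins together.
        pass


    accelerator = MyAccelerator(
        precision_plugin=NativeMixedPrecisionPlugin(),
        training_type_plugin=DDPPlugin(),
    )
    trainer = Trainer(accelerator=accelerator)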
----------
Accelerator API
---------------
.. currentmodule:: pytorch_lightning.accelerators
.. autosummary::
:nosignatures:
:template: classtemplate.rst
Accelerator
CPUAccelerator
GPUAccelerator
TPUAccelerator

View File

@@ -1,3 +1,5 @@
.. _plugins:
#######
Plugins
#######

View File

@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
@@ -26,11 +27,6 @@ from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
if TYPE_CHECKING:
from torch.cuda.amp import GradScaler
from pytorch_lightning.trainer.trainer import Trainer
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
@@ -40,6 +36,7 @@ class Accelerator(object):
An Accelerator is meant to deal with one type of Hardware.
Currently there are accelerators for:
- CPU
- GPU
- TPU
@@ -79,9 +76,10 @@ class Accelerator(object):
"""
self.training_type_plugin.setup_environment()
def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
"""
Sets up plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance
model: the LightningModule
@@ -91,23 +89,23 @@ class Accelerator(object):
self.setup_optimizers(trainer)
self.setup_precision_plugin(self.precision_plugin)
def start_training(self, trainer: 'Trainer') -> None:
def start_training(self, trainer: 'pl.Trainer') -> None:
self.training_type_plugin.start_training(trainer)
def start_evaluating(self, trainer: 'Trainer') -> None:
def start_evaluating(self, trainer: 'pl.Trainer') -> None:
self.training_type_plugin.start_evaluating(trainer)
def start_predicting(self, trainer: 'Trainer') -> None:
def start_predicting(self, trainer: 'pl.Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)
def pre_dispatch(self, trainer: 'Trainer') -> None:
def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.pre_dispatch()
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.precision_plugin.pre_dispatch()
def post_dispatch(self, trainer: 'Trainer') -> None:
def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()
@@ -169,12 +167,13 @@ class Accelerator(object):
Args:
args: the arguments for the models training step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): Integer displaying index of this batch
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
- batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
- batch_idx (int): Integer displaying index of this batch
- optimizer_idx (int): When using multiple optimizers, this argument will also be present.
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
"""
args[0] = self.to_device(args[0])
@@ -190,11 +189,12 @@ class Accelerator(object):
Args:
args: the arguments for the models validation step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
- batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
- batch_idx (int): The index of this batch
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
batch = self.to_device(args[0])
@@ -208,11 +208,12 @@ class Accelerator(object):
Args:
args: the arguments for the models test step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
- batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
- batch_idx (int): The index of this batch.
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
batch = self.to_device(args[0])
@@ -226,11 +227,13 @@ class Accelerator(object):
Args:
args: the arguments for the models predict step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
- batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
- batch_idx (int): The index of this batch.
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
batch = self.to_device(args[0])
@@ -336,7 +339,7 @@ class Accelerator(object):
"""Hook to do something at the end of the training"""
pass
def setup_optimizers(self, trainer: 'Trainer') -> None:
def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
"""creates optimizers and schedulers
Args:
@@ -385,7 +388,7 @@ class Accelerator(object):
return self.precision_plugin.precision
@property
def scaler(self) -> Optional['GradScaler']:
def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
return getattr(self.precision_plugin, 'scaler', None)
@@ -423,6 +426,7 @@ class Accelerator(object):
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
@@ -451,7 +455,8 @@ class Accelerator(object):
shard the model instantly - useful for extremely large models. Can save memory and
initialization time.
Returns: Model parallel context.
Returns:
Model parallel context.
"""
with self.training_type_plugin.model_sharded_context():
yield
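A hedged usage sketch of this context manager (``MyHugeModel`` is a hypothetical module; ``accelerator`` is
an already-configured Accelerator):

# Illustrative sketch, not part of this commit.
with accelerator.model_sharded_context():
    model = MyHugeModel()  # parameters can be created in sharded form, saving memory and init time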
@@ -498,7 +503,9 @@ class Accelerator(object):
"""
Allow model parallel hook to be called in suitable environments determined by the training type plugin.
This is useful when we want to shard the model once within fit.
Returns: True if we want to call the model parallel setup hook.
Returns:
True if we want to call the model parallel setup hook.
"""
return self.training_type_plugin.call_configure_sharded_model_hook
@@ -512,7 +519,9 @@ class Accelerator(object):
Override to delay setting optimizers and schedulers till after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However, this may break certain precision plugins, such as APEX, which require optimizers to be set.
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
Returns:
If True, delay setup optimizers until `pre_dispatch`, else call within `setup`.
"""
return self.training_type_plugin.setup_optimizers_in_pre_dispatch
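Because the accelerator simply defers to its training type plugin here, a custom plugin can opt into delayed
optimizer setup with a one-line override. A hedged sketch (``MyWrappingPlugin`` is a hypothetical name):

# Illustrative sketch, not part of this commit.
from pytorch_lightning.plugins import DDPPlugin


class MyWrappingPlugin(DDPPlugin):
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Ask the Accelerator to create optimizers in ``pre_dispatch``,
        # after the model has been wrapped, instead of in ``setup``.
        return True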

View File

@@ -11,20 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
class CPUAccelerator(Accelerator):
""" Accelerator for CPU devices. """
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
Raises:
MisconfigurationException:

View File

@@ -13,24 +13,22 @@
# limitations under the License.
import logging
import os
from typing import Any, TYPE_CHECKING
from typing import Any
import torch
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
_log = logging.getLogger(__name__)
class GPUAccelerator(Accelerator):
""" Accelerator for GPU devices. """
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
Raises:
MisconfigurationException:

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, TYPE_CHECKING, Union
from typing import Any, Callable, Union
from torch.optim import Optimizer
@@ -28,14 +28,13 @@ if _XLA_AVAILABLE:
xla_clip_grad_norm_ = clip_grad_norm_
if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
import pytorch_lightning as pl
class TPUAccelerator(Accelerator):
""" Accelerator for TPU devices. """
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
Raises:
MisconfigurationException: