Accelerator API docs (#6936)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent b85cfbe8f3
commit 20ff50caa6

@@ -1,6 +1,21 @@
 API References
 ==============
 
+Accelerator API
+---------------
+
+.. currentmodule:: pytorch_lightning.accelerators
+
+.. autosummary::
+    :toctree: api
+    :nosignatures:
+    :template: classtemplate.rst
+
+    Accelerator
+    CPUAccelerator
+    GPUAccelerator
+    TPUAccelerator
+
 Core API
 --------
 
@@ -1,10 +1,56 @@
 .. _accelerators:
 
+############
 Accelerators
+############
 Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators
-also manage distributed accelerators (like DP, DDP, HPC cluster).
-
-Accelerators can also be configured to run on arbitrary clusters using Plugins or to link up to arbitrary
+also manage distributed communication through :ref:`Plugins` (like DP, DDP, HPC cluster) and
+can also be configured to run on arbitrary clusters or to link up to arbitrary
 computational strategies like 16-bit precision via AMP and Apex.
 
-**For help setting up custom plugin/accelerator please reach out to us at support@pytorchlightning.ai**
+An Accelerator is meant to deal with one type of hardware.
+Currently there are accelerators for:
+
+- CPU
+- GPU
+- TPU
+
+Each Accelerator gets two plugins upon initialization:
+One to handle differences from the training routine and one to handle different precisions.
+
+.. testcode::
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.accelerators import GPUAccelerator
+    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
+
+    accelerator = GPUAccelerator(
+        precision_plugin=NativeMixedPrecisionPlugin(),
+        training_type_plugin=DDPPlugin(),
+    )
+    trainer = Trainer(accelerator=accelerator)
+
+
+We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
+hardware and distributed training or clusters.
+
+
+.. warning:: The Accelerator API is in beta and subject to change.
+    For help setting up custom plugins/accelerators, please reach out to us at **support@pytorchlightning.ai**
+
+----------
+
+
+Accelerator API
+---------------
+
+.. currentmodule:: pytorch_lightning.accelerators
+
+.. autosummary::
+    :nosignatures:
+    :template: classtemplate.rst
+
+    Accelerator
+    CPUAccelerator
+    GPUAccelerator
+    TPUAccelerator
 
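The page above notes that Accelerators and Plugins are exposed mainly for expert users who want to extend Lightning. A hedged sketch of what such an extension might look like against this beta API, assuming the constructor shown in the testcode; the subclass name and the plugin choices are illustrative, not part of the commit:

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import CPUAccelerator
    from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin


    class LoggingCPUAccelerator(CPUAccelerator):
        """Toy accelerator that announces when training is dispatched."""

        def start_training(self, trainer):
            print("handing control to the training type plugin")
            return super().start_training(trainer)


    accelerator = LoggingCPUAccelerator(
        precision_plugin=PrecisionPlugin(),
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
    )
    trainer = Trainer(accelerator=accelerator)
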
@@ -1,3 +1,5 @@
 .. _plugins:
 
+#######
 Plugins
+#######
 
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union
 
 import torch
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+import pytorch_lightning as pl
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 
@@ -26,11 +27,6 @@ from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 
-if TYPE_CHECKING:
-    from torch.cuda.amp import GradScaler
-
-    from pytorch_lightning.trainer.trainer import Trainer
-
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
 
 
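The import changes above (and the matching ones in the CPU, GPU and TPU accelerators below) replace the `if TYPE_CHECKING:` blocks with a module-level `import pytorch_lightning as pl` plus string annotations such as 'pl.Trainer'. A minimal sketch of why that sidesteps the circular import, relying only on standard Python semantics; the function below is illustrative, not from the diff:

    # String annotations are not evaluated at import time, so this module can be
    # imported while pytorch_lightning.Trainer is still being defined elsewhere.
    import pytorch_lightning as pl


    def describe(trainer: 'pl.Trainer') -> str:
        # 'pl.Trainer' stays a plain string unless something like
        # typing.get_type_hints() resolves it later.
        return f"max_epochs={trainer.max_epochs}"
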
@@ -40,6 +36,7 @@ class Accelerator(object):
     An Accelerator is meant to deal with one type of Hardware.
+
     Currently there are accelerators for:
 
     - CPU
     - GPU
     - TPU
 
@@ -79,9 +76,10 @@ class Accelerator(object):
         """
         self.training_type_plugin.setup_environment()
 
-    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
         """
         Setup plugins for the trainer fit and creates optimizers.
+
         Args:
             trainer: the trainer instance
             model: the LightningModule
 
@@ -91,23 +89,23 @@ class Accelerator(object):
         self.setup_optimizers(trainer)
         self.setup_precision_plugin(self.precision_plugin)
 
-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_training(trainer)
 
-    def start_evaluating(self, trainer: 'Trainer') -> None:
+    def start_evaluating(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_evaluating(trainer)
 
-    def start_predicting(self, trainer: 'Trainer') -> None:
+    def start_predicting(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)
 
-    def pre_dispatch(self, trainer: 'Trainer') -> None:
+    def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()
 
-    def post_dispatch(self, trainer: 'Trainer') -> None:
+    def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
 
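The hooks touched above are the ones the Trainer drives in sequence: setup, pre_dispatch, one of start_training/start_evaluating/start_predicting, then post_dispatch, with optimizer creation happening in setup or delayed to pre_dispatch depending on the training type plugin. A self-contained sketch of that sequencing with plain stand-in classes; nothing here is PyTorch Lightning API, only the hook names mirror it:

    class ToyAccelerator:
        def __init__(self, setup_optimizers_in_pre_dispatch: bool = False) -> None:
            self.delay_optimizers = setup_optimizers_in_pre_dispatch

        def setup(self) -> None:
            if not self.delay_optimizers:
                print("setup: creating optimizers")

        def pre_dispatch(self) -> None:
            if self.delay_optimizers:
                print("pre_dispatch: creating optimizers after the model was wrapped")

        def start_training(self) -> None:
            print("start_training: delegating to the training type plugin")

        def post_dispatch(self) -> None:
            print("post_dispatch: tearing down plugins")


    acc = ToyAccelerator(setup_optimizers_in_pre_dispatch=True)
    for hook in (acc.setup, acc.pre_dispatch, acc.start_training, acc.post_dispatch):
        hook()
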
@@ -169,12 +167,13 @@ class Accelerator(object):
 
         Args:
             args: the arguments for the models training step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): Integer displaying index of this batch
-                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                hiddens(:class:`~torch.Tensor`): Passed in if
-                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+
+            - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+              The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+            - batch_idx (int): Integer displaying index of this batch
+            - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
+            - hiddens(:class:`~torch.Tensor`): Passed in if
+              :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
 
         """
         args[0] = self.to_device(args[0])
 
@@ -190,11 +189,12 @@
 
         Args:
             args: the arguments for the models validation step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple val dataloaders used)
+
+            - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+              The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+            - batch_idx (int): The index of this batch
+            - dataloader_idx (int): The index of the dataloader that produced this batch
+              (only if multiple val dataloaders used)
         """
         batch = self.to_device(args[0])
 
@@ -208,11 +208,12 @@
 
         Args:
             args: the arguments for the models test step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple test dataloaders used).
+
+            - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+              The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+            - batch_idx (int): The index of this batch.
+            - dataloader_idx (int): The index of the dataloader that produced this batch
+              (only if multiple test dataloaders used).
         """
         batch = self.to_device(args[0])
 
@@ -226,11 +227,13 @@
 
         Args:
             args: the arguments for the models predict step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple predict dataloaders used).
+
+            - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+              The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+            - batch_idx (int): The index of this batch.
+            - dataloader_idx (int): The index of the dataloader that produced this batch
+              (only if multiple predict dataloaders used).
+
         """
         batch = self.to_device(args[0])
 
@@ -336,7 +339,7 @@ class Accelerator(object):
         """Hook to do something at the end of the training"""
         pass
 
-    def setup_optimizers(self, trainer: 'Trainer') -> None:
+    def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         """creates optimizers and schedulers
 
         Args:
 
@@ -385,7 +388,7 @@ class Accelerator(object):
         return self.precision_plugin.precision
 
     @property
-    def scaler(self) -> Optional['GradScaler']:
+    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
 
         return getattr(self.precision_plugin, 'scaler', None)
 
@@ -423,6 +426,7 @@ class Accelerator(object):
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
             sync_grads: flag that allows users to synchronize gradients for all_gather op
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
 
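For the shape contract documented above, a (batch, ...) tensor going in and a (world_size, batch, ...) tensor coming back, here is a hedged usage sketch. It goes through the LightningModule convenience of the same name, which, to my understanding, forwards to the accelerator's all_gather when running under a Trainer; the module and tensor below are illustrative:

    import torch
    import pytorch_lightning as pl


    class GatherExample(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            local_scores = torch.rand(8)              # shape: (batch,)
            gathered = self.all_gather(local_scores)  # shape: (world_size, batch) under DDP
            # sync_grads=True would keep the gather differentiable across processes
            return gathered.mean()
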
@@ -451,7 +455,8 @@ class Accelerator(object):
         shard the model instantly - useful for extremely large models. Can save memory and
         initialization time.
 
-        Returns: Model parallel context.
+        Returns:
+            Model parallel context.
         """
         with self.training_type_plugin.model_sharded_context():
             yield
 
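A small sketch of how the context manager above might be used. With a non-sharding training type plugin the context is effectively a no-op, while sharding-aware plugins can materialize parameters directly in sharded form; the plugin choice here is an assumption made for illustration:

    import torch

    from pytorch_lightning.accelerators import CPUAccelerator
    from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

    accelerator = CPUAccelerator(
        precision_plugin=PrecisionPlugin(),
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
    )

    with accelerator.model_sharded_context():
        # Layers created inside the context can be sharded immediately by
        # plugins that support it, saving memory and initialization time.
        layer = torch.nn.Linear(4, 4)
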
@@ -498,7 +503,9 @@ class Accelerator(object):
         """
         Allow model parallel hook to be called in suitable environments determined by the training type plugin.
         This is useful for when we want to shard the model once within fit.
-        Returns: True if we want to call the model parallel setup hook.
+
+        Returns:
+            True if we want to call the model parallel setup hook.
         """
         return self.training_type_plugin.call_configure_sharded_model_hook
 
@@ -512,7 +519,9 @@ class Accelerator(object):
         Override to delay setting optimizers and schedulers till after dispatch.
         This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
         However this may break certain precision plugins such as APEX which require optimizers to be set.
-        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+
+        Returns:
+            If True, delay setup optimizers until `pre_dispatch`, else call within `setup`.
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch
 
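A hedged sketch of a training type plugin opting in to the delayed optimizer setup described above; the subclass and base-class choice are illustrative, and only the property name comes from the diff:

    from pytorch_lightning.plugins import SingleDevicePlugin


    class DeferredOptimizerPlugin(SingleDevicePlugin):
        @property
        def setup_optimizers_in_pre_dispatch(self) -> bool:
            # Ask the Accelerator to create optimizers in pre_dispatch(),
            # after this plugin has wrapped/moved the model.
            return True
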
@@ -11,20 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
-
+import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
 
 
 class CPUAccelerator(Accelerator):
     """ Accelerator for CPU devices. """
 
-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:
 
@@ -13,24 +13,22 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, TYPE_CHECKING
+from typing import Any
 
 import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
 
 _log = logging.getLogger(__name__)
 
 
 class GPUAccelerator(Accelerator):
     """ Accelerator for GPU devices. """
 
-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:
 
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, TYPE_CHECKING, Union
+from typing import Any, Callable, Union
 
 from torch.optim import Optimizer
 
@@ -28,14 +28,13 @@ if _XLA_AVAILABLE:
 
     xla_clip_grad_norm_ = clip_grad_norm_
 
-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
+import pytorch_lightning as pl
+
 
 class TPUAccelerator(Accelerator):
     """ Accelerator for TPU devices. """
 
-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException: