lightning/pytorch_lightning/accelerators/accelerator.py

424 lines
17 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
if TYPE_CHECKING:
from torch.cuda.amp import GradScaler
from pytorch_lightning.trainer.trainer import Trainer
# Return annotation shared by the step / step-end methods of ``Accelerator``:
# a single tensor, a dict mapping names to tensors, or ``None``.
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
class Accelerator:
    """
    The Accelerator Base Class.
    An Accelerator is meant to deal with one type of Hardware.

    Currently there are accelerators for:

    - CPU
    - GPU
    - TPU

    Each Accelerator gets two plugins upon initialization:
    One to handle differences from the training routine and one to handle different precisions.
    """

    def __init__(
        self,
        precision_plugin: PrecisionPlugin,
        training_type_plugin: TrainingTypePlugin,
    ) -> None:
        """
        Args:
            precision_plugin: the plugin to handle precision-specific parts
            training_type_plugin: the plugin to handle different training routines
        """
        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin

        # Filled in by ``setup_optimizers`` and possibly replaced by ``connect_precision_plugin``.
        self.optimizers: Sequence = []
        self.lr_schedulers: Sequence = []
        self.optimizer_frequencies: Sequence = []

    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
        """
        Connects the plugins to the training process, creates optimizers

        Args:
            trainer: the trainer instance to connect to
            model: the model to train
        """
        # Order matters: ``setup_optimizers`` reads ``self.lightning_module`` (served by the
        # training type plugin), and ``connect_precision_plugin`` hands the freshly created
        # optimizers and schedulers to the precision plugin.
        self.connect_training_type_plugin(self.training_type_plugin, model)
        self.setup_optimizers(trainer)
        self.connect_precision_plugin(self.precision_plugin)

    def start_training(self, trainer: 'Trainer') -> None:
        """Delegates the start of training to the training type plugin."""
        self.training_type_plugin.start_training(trainer)

    def start_testing(self, trainer: 'Trainer') -> None:
        """Delegates the start of testing to the training type plugin."""
        self.training_type_plugin.start_testing(trainer)

    def start_predicting(self, trainer: 'Trainer') -> None:
        """Delegates the start of predicting to the training type plugin."""
        self.training_type_plugin.start_predicting(trainer)

    def pre_dispatch(self) -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.training_type_plugin.pre_dispatch()
        self.precision_plugin.pre_dispatch()

    def post_dispatch(self) -> None:
        """Hook to do something after the training/evaluation/prediction finishes."""
        self.training_type_plugin.post_dispatch()
        self.precision_plugin.post_dispatch()

    @property
    def model(self) -> torch.nn.Module:
        """Returns the model. This can also be a wrapped LightningModule.
        For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`
        """
        return self.training_type_plugin.model

    @model.setter
    def model(self, new_model: torch.nn.Module) -> None:
        # The model is owned by the training type plugin; the accelerator only forwards to it.
        self.training_type_plugin.model = new_model

    @property
    def lightning_module(self) -> LightningModule:
        """Returns the pure LightningModule.
        To get the potentially wrapped model use :attr:`Accelerator.model`
        """
        return self.training_type_plugin.lightning_module

    @property
    def root_device(self) -> torch.device:
        """The root device as reported by the training type plugin."""
        return self.training_type_plugin.root_device

    def teardown(self) -> None:
        """This method is called to teardown the training process.
        It is the right place to release memory and free other resources.
        """

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
        """Moves the batch to the correct device.
        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
        """
        model = self.lightning_module
        if model is not None:
            # Let the LightningModule apply its own batch-transfer handling when one is attached.
            return model._apply_batch_transfer_handler(batch, device)
        return move_data_to_device(batch, device)

    def on_train_start(self) -> None:
        """Hook to do something upon the training start"""

    def training_step(
        self,
        args: List[Union[Any, int]],
    ) -> _STEP_OUTPUT_TYPE:
        """The actual training step.

        Note that ``args[0]`` (the batch) is replaced in place by its on-device version.

        Args:
            args: the arguments for the models training step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): Integer displaying index of this batch
                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
                hiddens(:class:`~torch.Tensor`): Passed in if
                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
        """
        args[0] = self.to_device(args[0])

        with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
            return self.training_type_plugin.training_step(*args)

    def post_training_step(self) -> None:
        """Forwards the post-training-step hook to the training type plugin."""
        self.training_type_plugin.post_training_step()

    def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual validation step.

        Note that ``args[0]`` (the batch) is replaced in place by its on-device version.

        Args:
            args: the arguments for the models validation step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple val dataloaders used)
        """
        args[0] = self.to_device(args[0])

        with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
            return self.training_type_plugin.validation_step(*args)

    def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual test step.

        Note that ``args[0]`` (the batch) is replaced in place by its on-device version.

        Args:
            args: the arguments for the models test step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch.
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple test dataloaders used).
        """
        args[0] = self.to_device(args[0])

        with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
            return self.training_type_plugin.test_step(*args)

    def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual predict step.

        Note that ``args[0]`` (the batch) is replaced in place by its on-device version.

        Args:
            args: the arguments for the models predict step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch.
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple predict dataloaders used).
        """
        args[0] = self.to_device(args[0])

        with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
            return self.training_type_plugin.predict(*args)

    def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the training step

        Args:
            output: the output of the training step
        """
        return self.training_type_plugin.training_step_end(output)

    def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the test step

        Args:
            output: the output of the test step
        """
        return self.training_type_plugin.test_step_end(output)

    def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the validation step

        Args:
            output: the output of the validation step
        """
        return self.training_type_plugin.validation_step_end(output)

    def backward(
        self,
        closure_loss: torch.Tensor,
        optimizer: Optimizer,
        optimizer_idx: int,
        should_accumulate: bool,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forwards backward-calls to the precision plugin.

        The training type plugin's ``pre_backward`` / ``post_backward`` hooks are invoked
        around the precision plugin's ``backward`` call.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
            optimizer: the optimizer the backward pass is performed for
            optimizer_idx: index of that optimizer
            should_accumulate: whether to accumulate gradients
            *args: additional positional arguments forwarded to the precision plugin
            **kwargs: additional keyword arguments forwarded to the precision plugin

        Returns:
            Whatever the precision plugin's ``backward`` returns.
        """
        self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)

        output = self.precision_plugin.backward(
            self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs
        )

        self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)

        return output

    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
        """performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            opt_idx: index of the current optimizer
            lambda_closure: closure calculating the loss value
            **kwargs: forwarded to the plugins' pre/step/post calls
        """
        # The precision plugin decides whether the accelerator should run the step itself.
        make_optimizer_step = self.precision_plugin.pre_optimizer_step(
            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
        )
        if make_optimizer_step:
            self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
        # Post-step hooks are invoked regardless of who performed the step.
        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

    def run_optimizer_step(
        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    ) -> None:
        """Runs the optimizer step through the training type plugin.

        ``optimizer_idx`` is accepted but not forwarded to the plugin.
        """
        self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
        """Zeros all model parameter's gradients"""
        model_ref = self.lightning_module
        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """clips all the optimizer parameters to the given value"""
        self.precision_plugin.clip_gradients(optimizer, clip_val)

    def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
        """Hook to do something on the end of a training epoch

        Args:
            outputs: the outputs of the training steps
        """

    def on_train_end(self) -> None:
        """Hook to do something at the end of the training"""

    def setup_optimizers(self, trainer: 'Trainer') -> None:
        """creates optimizers and schedulers

        Nothing is created when ``trainer.testing`` is truthy.

        Args:
            trainer: the Trainer, these optimizers should be connected to
        """
        if trainer.testing:
            return
        optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
            trainer=trainer, model=self.lightning_module
        )
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.optimizer_frequencies = optimizer_frequencies

    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
        """Attaches the training type plugin to the accelerator.
        Also transfers ownership of the model to this plugin
        """
        plugin.connect(model)

    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
        """Attaches the precision plugin to the accelerator

        The plugin receives the current model, optimizers and lr schedulers and returns
        (possibly modified) versions of them, which are stored back on the accelerator.
        """
        model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
        self.model = model
        self.optimizers = optimizers
        # Store under ``lr_schedulers`` -- the attribute the rest of the class reads and the one
        # passed to ``plugin.connect`` above -- so the plugin's schedulers are not dropped.
        self.lr_schedulers = schedulers

    def to_device(self, batch: Any) -> Any:
        """Pushes the batch to the root device"""
        return self.batch_to_device(batch, self.root_device)

    @property
    def amp_backend(self) -> Optional[LightningEnum]:
        """The AMP backend implied by the precision plugin's type, or ``None`` if it is neither."""
        if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
            return AMPType.APEX
        elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
            return AMPType.NATIVE
        return None

    @property
    def precision(self) -> Union[str, int]:
        """The precision reported by the precision plugin."""
        return self.precision_plugin.precision

    @property
    def scaler(self) -> Optional['GradScaler']:
        """The precision plugin's ``scaler`` attribute, or ``None`` if it has none."""
        return getattr(self.precision_plugin, 'scaler', None)

    @property
    def rpc_enabled(self) -> bool:
        """Whether the training type plugin reports RPC as enabled."""
        return self.training_type_plugin.rpc_enabled

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
        """
        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
        plugins.
        """
        # Fall back to the optimizer's own ``state_dict()`` when the plugin has no ``optimizer_state``.
        return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
        """Passes the checkpoint through the training type plugin's ``on_save`` and returns the result."""
        return self.training_type_plugin.on_save(checkpoint)

    def barrier(self, name: Optional[str] = None) -> None:
        """Forwards the barrier call to the training type plugin."""
        self.training_type_plugin.barrier(name=name)

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.

        Args:
            obj: Object to broadcast to all process, usually a tensor or collection of tensors.
            src: The source rank of which the object will be broadcast from
        """
        return self.training_type_plugin.broadcast(obj, src)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """
        Function to gather a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for all_gather op

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Wraps the dataloader if necessary

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return self.training_type_plugin.process_dataloader(dataloader)

    @property
    def results(self) -> Any:
        """
        The results of the last training/testing run will be cached within the training type plugin.
        In distributed training, we make sure to transfer the results to the appropriate master process.
        """
        return self.training_type_plugin.results