lightning/pytorch_lightning/core/hooks.py

676 lines
22 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Union
import torch
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
try:
from apex import amp
except ImportError:
amp = None
class ModelHooks:
def setup(self, stage: str):
"""
Called at the beginning of fit and test.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.
Args:
stage: either 'fit' or 'test'
Example::
class LitModel(...):
def __init__(self):
self.l1 = None
def prepare_data(self):
download_data()
tokenize()
# don't do this
self.something = else
def setup(stage):
data = Load_data(...)
self.l1 = nn.Linear(28, data.num_classes)
"""
def teardown(self, stage: str):
"""
Called at the end of fit and test.
Args:
stage: either 'fit' or 'test'
"""
def on_fit_start(self):
"""
Called at the very beginning of fit.
If on DDP it is called on every process
"""
def on_fit_end(self):
"""
Called at the very end of fit.
If on DDP it is called on every process
"""
def on_train_start(self) -> None:
"""
Called at the beginning of training before sanity check.
"""
# do something at the start of training
def on_train_end(self) -> None:
"""
Called at the end of training before logger experiment is closed.
"""
# do something at the end of training
def on_pretrain_routine_start(self) -> None:
"""
Called at the beginning of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
"""
# do something at the start of the pretrain routine
def on_pretrain_routine_end(self) -> None:
"""
Called at the end of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
"""
# do something at the end of the pretrain routine
def on_train_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch starts
def on_train_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the training loop after the batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch ends
def on_validation_model_eval(
self
) -> None:
"""
Sets the model to eval during the val loop
"""
self.eval()
def on_validation_model_train(
self
) -> None:
"""
Sets the model to train during the val loop
"""
self.train()
def on_validation_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the validation loop before anything happens for that batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch starts
def on_validation_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the validation loop after the batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch ends
def on_test_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the test loop before anything happens for that batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch starts
def on_test_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the test loop after the batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch ends
def on_test_model_eval(
self
) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()
def on_test_model_train(
self
) -> None:
"""
Sets the model to train during the test loop
"""
self.train()
def on_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
Args:
batch: The batched data as it is returned by the training DataLoader.
.. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_start` instead)
"""
# do something when the batch starts
def on_batch_end(self) -> None:
"""
Called in the training loop after the batch.
.. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_end` instead)
"""
# do something when the batch ends
def on_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_train_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_train_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_validation_epoch_start(self) -> None:
"""
Called in the validation loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_test_epoch_start(self) -> None:
"""
Called in the test loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""
Called after optimizer.step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads.
Good place to inspect weight information with weights updated.
This is where it is called::
for optimizer in optimizers:
optimizer.step()
model.on_before_zero_grad(optimizer) # < ---- called here
optimizer.zero_grad()
Args:
optimizer: The optimizer for which grads should be zeroed.
"""
# do something with the optimizer or inspect it.
def on_after_backward(self) -> None:
"""
Called in the training loop after loss.backward() and before optimizers do anything.
This is the ideal place to inspect or log gradient information.
Example::
def on_after_backward(self):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
params = self.state_dict()
for k, v in params.items():
grads = v
name = k
self.logger.experiment.add_histogram(tag=name, values=grads,
global_step=self.trainer.global_step)
"""
def backward(
self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
) -> None:
"""
Override backward with your own implementation if you need to.
Args:
trainer: Pointer to the trainer
loss: Loss is already scaled by accumulated grads
optimizer: Current optimizer being used
optimizer_idx: Index of the current optimizer being used
Called to perform backward step.
Feel free to override as needed.
The loss passed in has already been scaled for accumulated gradients if requested.
Example::
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()
"""
loss.backward()
def amp_scale_loss(
self,
unscaled_loss: Tensor,
optimizer: Optimizer,
optimizer_idx: int,
amp_backend: AMPType,
):
if amp_backend == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
return scaled_loss
class DataHooks:
def prepare_data(self) -> None:
"""
Use this to download and prepare data.
.. warning:: DO NOT set state to the model (use `setup` instead)
since this is NOT called on every GPU in DDP/TPU
Example::
def prepare_data(self):
# good
download_data()
tokenize()
etc()
# bad
self.split = data_split
self.some_state = some_other_state()
In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)):
1. Once per node. This is the default and is only called on LOCAL_RANK=0.
2. Once in total. Only called on GLOBAL_RANK=0.
Example::
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)
# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
This is called before requesting the dataloaders:
.. code-block:: python
model.prepare_data()
if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
"""
def train_dataloader(self) -> DataLoader:
"""
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`setup`
- :meth:`train_dataloader`
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Example:
.. code-block:: python
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=True
)
return loader
"""
rank_zero_warn(
"`train_dataloader` must be implemented to be used with the Lightning Trainer"
)
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`setup`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Return:
Single or multiple PyTorch DataLoaders.
Example:
.. code-block:: python
def test_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)
return loader
# can also return multiple dataloaders
def test_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
Note:
If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
this method.
Note:
In the case where you return multiple test dataloaders, the :meth:`test_step`
will have an argument ``dataloader_idx`` which matches the order here.
"""
def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware
There is no need to set it yourself.
Return:
Single or multiple PyTorch DataLoaders.
Examples:
.. code-block:: python
def val_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False,
transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)
return loader
# can also return multiple dataloaders
def val_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
Note:
If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
implement this method.
Note:
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
will have an argument ``dataloader_idx`` which matches the order here.
"""
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:`torch.Tensor` or anything that implements `.to(...)`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- :class:`torchtext.data.batch.Batch`
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Example::
def transfer_batch_to_device(self, batch, device)
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
else:
batch = super().transfer_batch_to_device(data, device)
return batch
Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
Returns:
A reference to the data on the new device.
Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
return move_data_to_device(batch, device)
class CheckpointHooks:
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning to restore your model.
If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
Args:
checkpoint: Loaded checkpoint
Example:
.. code-block:: python
def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note:
Lightning auto-restores global step, epoch, and train state including amp scaling.
There is no need for you to restore anything regarding training.
"""
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning when saving a checkpoint to give you a chance to store anything
else you might want to save.
Args:
checkpoint: Checkpoint to be saved
Example:
.. code-block:: python
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note:
Lightning saves all aspects of training (epoch, global step, etc...)
including amp scaling.
There is no need for you to store anything about training.
"""