lightning/pytorch_lightning/core/hooks.py

"""
Hooks
=====
There are cases when you might want to do something different at different parts of the training/validation loop.
To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
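
As a rough illustration of the override pattern described above, the following sketch uses stand-in classes rather than the real library. ``ToyHooks``, ``CountingModel`` and ``ToyTrainer`` are hypothetical names, and the call order in ``fit`` is a simplification of what the actual trainer does.

```python
# Minimal, library-free sketch of the hook pattern (all names are hypothetical).


class ToyHooks:
    # Default hooks do nothing; subclasses override only what they need.
    def on_epoch_start(self):
        pass

    def on_batch_start(self, batch):
        pass

    def on_batch_end(self):
        pass

    def on_epoch_end(self):
        pass


class CountingModel(ToyHooks):
    def __init__(self):
        self.events = []

    # Overriding a hook is all that is needed to enable it.
    def on_epoch_start(self):
        self.events.append("epoch_start")

    def on_batch_end(self):
        self.events.append("batch_end")


class ToyTrainer:
    # Stand-in for the trainer: it calls each hook at a fixed point in the loop.
    def fit(self, model, batches, epochs=1):
        for _ in range(epochs):
            model.on_epoch_start()
            for batch in batches:
                model.on_batch_start(batch)
                # ... forward, backward and optimizer step would go here ...
                model.on_batch_end()
            model.on_epoch_end()


model = CountingModel()
ToyTrainer().fit(model, batches=[0, 1])
print(model.events)
```

Hooks that are not overridden fall back to the no-op defaults, so the toy trainer can call every hook unconditionally.
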
**Contributing** If there's a hook you'd like to add, simply:

1. Fork PyTorchLightning.

2. Add the hook to :py:mod:`pytorch_lightning.core.hooks`.

3. Add a call to the hook at the correct place in :py:mod:`pytorch_lightning.trainer.trainer`.
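
The last two contribution steps can be sketched roughly as follows. This is an illustration under assumptions, not the real trainer code: ``SketchHooks``, ``on_toy_step_end``, ``toy_training_loop`` and ``RecordingModel`` are hypothetical names chosen for the example.

```python
# Sketch of adding a new hook (hypothetical names, not the real trainer code).


class SketchHooks:
    def on_toy_step_end(self, step):
        # Step 2: the new hook gets a no-op default so existing models keep working.
        pass


def toy_training_loop(model, num_steps):
    for step in range(num_steps):
        # ... the actual work of the step would go here ...
        # Step 3: call the new hook from the place in the loop where it belongs.
        model.on_toy_step_end(step)


class RecordingModel(SketchHooks):
    def __init__(self):
        self.seen = []

    # A user model opts in by overriding the new hook.
    def on_toy_step_end(self, step):
        self.seen.append(step)


plain = SketchHooks()
toy_training_loop(plain, num_steps=3)  # default hook: nothing happens

recorder = RecordingModel()
toy_training_loop(recorder, num_steps=3)
print(recorder.seen)
```

Giving the new hook an empty default keeps the change backward compatible: models that never override it behave exactly as before.
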
"""
import torch

try:
    from apex import amp

    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class ModelHooks(torch.nn.Module):

    def on_sanity_check_start(self):
        """
        Called before starting evaluation.

        .. warning:: will be deprecated.

        :return:
        """

    def on_train_start(self):
        """Called at the beginning of training, before the sanity check.

        :return:
        """
        # do something at the start of training

    def on_train_end(self):
        """
        Called at the end of training, before the logger experiment is closed.

        :return:
        """
        # do something at the end of training

    def on_batch_start(self, batch):
        """Called in the training loop before anything happens for that batch.

        :param batch:
        :return:
        """
        # do something when the batch starts

    def on_batch_end(self):
        """Called in the training loop after the batch."""
        # do something when the batch ends

    def on_epoch_start(self):
        """Called in the training loop at the very beginning of the epoch."""
        # do something when the epoch starts

    def on_epoch_end(self):
        """Called in the training loop at the very end of the epoch."""
        # do something when the epoch ends

    def on_pre_performance_check(self):
        """Called at the very beginning of the validation loop."""
        # do something before validation starts

    def on_post_performance_check(self):
        """Called at the very end of the validation loop."""
        # do something before validation ends

    def on_before_zero_grad(self, optimizer):
        """Called after optimizer.step() and before optimizer.zero_grad().

        Called in the training loop after taking an optimizer step and before zeroing grads.
        This is a good place to inspect weight information once the weights have been updated.

        .. code-block:: python

            for optimizer in optimizers:
                optimizer.step()
                model.on_before_zero_grad(optimizer)  # <---- called here
                optimizer.zero_grad()

        :param optimizer:
        :return:
        """
        # do something with the optimizer or inspect it.

    def on_after_backward(self):
        """Called after loss.backward() and before optimizers do anything.

        Called in the training loop after model.backward().
        This is the ideal place to inspect or log gradient information.

        .. code-block:: python

            def on_after_backward(self):
                # example to inspect gradient information in tensorboard
                if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                    for name, param in self.named_parameters():
                        if param.grad is None:  # skip parameters without gradients
                            continue
                        self.logger.experiment.add_histogram(
                            tag=name, values=param.grad, global_step=self.trainer.global_step
                        )

        :return:
        """
    def backward(self, trainer, loss, optimizer, optimizer_idx):
        """Override backward with your own implementation if you need to.

        Called to perform the backward step.
        Feel free to override as needed.

        The loss passed in has already been scaled for accumulated gradients if requested.

        .. code-block:: python

            def backward(self, trainer, loss, optimizer, optimizer_idx):
                if trainer.precision == 16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

        :param trainer: Pointer to the trainer
        :param loss: Loss is already scaled by accumulated grads
        :param optimizer: Current optimizer being used
        :param optimizer_idx: Index of the current optimizer being used
        :return:
        """
        if trainer.precision == 16 and not trainer.on_tpu:
            # apex loss scaling is only needed off-TPU;
            # .backward is not special on 16-bit with TPUs
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()