lightning/pytorch_lightning/core/hooks.py

"""
# Hooks
There are cases when you might want to do something different at different parts of the training/validation loop.
To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
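
For example, a minimal sketch of overriding a hook in your own module (``CoolModel`` and the print are illustrative only, and the usual required ``LightningModule`` methods are omitted):

.. code-block:: python

    import pytorch_lightning as pl


    class CoolModel(pl.LightningModule):

        def on_epoch_end(self):
            # the trainer calls this automatically at the end of every training epoch
            print(f'epoch finished at global step {self.trainer.global_step}')
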
**Contributing** If there's a hook you'd like to add, simply:

1. Fork PyTorchLightning.
2. Add the hook to :py:mod:`pytorch_lightning.base_module.hooks.py`.
3. Add a call to the hook at the correct place in :py:mod:`pytorch_lightning.models.trainer`.
"""
import torch

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class ModelHooks(torch.nn.Module):

    def on_sanity_check_start(self):
        """Called before starting evaluation.

        :return:
        """
        pass

    def on_batch_start(self, batch):
        """Called in the training loop before anything happens for that batch.
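
        A minimal sketch of a possible override (``batch_start_time`` is a hypothetical attribute used only for this example, which also assumes ``import time`` at the top of your module):

        .. code-block:: python

            def on_batch_start(self, batch):
                # remember when this batch started so on_batch_end can report its duration
                self.batch_start_time = time.time()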
"""
# do something when the batch starts
2019-03-31 01:45:16 +00:00
pass
    def on_batch_end(self):
        """Called in the training loop after the batch.

        """
        # do something when the batch ends
        pass

    def on_epoch_start(self):
        """Called in the training loop at the very beginning of the epoch.

        """
        # do something when the epoch starts
        pass

    def on_epoch_end(self):
        """Called in the training loop at the very end of the epoch."""
        # do something when the epoch ends
        pass

    def on_pre_performance_check(self):
        """Called at the very beginning of the validation loop.

        """
        # do something before validation starts
        pass

    def on_post_performance_check(self):
        """Called at the very end of the validation loop."""
        # do something when validation ends
        pass

    def on_before_zero_grad(self, optimizer):
        """Called after optimizer.step() and before optimizer.zero_grad().

        Called in the training loop after taking an optimizer step and before zeroing grads.
        A good place to inspect weight information with the weights updated:

        .. code-block:: python

            for optimizer in optimizers:
                optimizer.step()
                model.on_before_zero_grad(optimizer)  # <-- called here
                optimizer.zero_grad()
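
        A minimal sketch of a possible override (the scalar logging assumes your logger's ``experiment`` exposes the TensorBoard ``add_scalar`` API, like the ``add_histogram`` example below):

        .. code-block:: python

            def on_before_zero_grad(self, optimizer):
                # log the overall weight norm right after the optimizer step
                total_norm = sum(p.data.norm(2).item() ** 2 for p in self.parameters()) ** 0.5
                self.logger.experiment.add_scalar(
                    'weight_norm', total_norm, global_step=self.trainer.global_step)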
"""
# do something with the optimizer or inspect it.
2019-07-21 22:15:58 +00:00
pass
    def on_after_backward(self):
        """Called after loss.backward() and before optimizers do anything.

        Called in the training loop after model.backward().
        This is the ideal place to inspect or log gradient information.

        .. code-block:: python

            def on_after_backward(self):
                # example to inspect gradient information in tensorboard
                if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                    for name, param in self.named_parameters():
                        if param.grad is not None:
                            self.logger.experiment.add_histogram(
                                tag=name, values=param.grad,
                                global_step=self.trainer.global_step)

        :return:
        """
        pass

    def backward(self, use_amp, loss, optimizer):
        """Override backward with your own implementation if you need to.

        Called to perform the backward step.
        Feel free to override as needed.
        The loss passed in has already been scaled for accumulated gradients if requested.

        .. code-block:: python

            def backward(self, use_amp, loss, optimizer):
                if use_amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

        :param use_amp: whether amp was requested or not
        :param loss: loss already scaled by accumulated grads
        :param optimizer: current optimizer being used
        :return:
        """
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()