# Hooks
[[Github Code](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py)]
There are cases when you might want to do something different at different parts of the training/validation loop.
To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
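The override mechanism can be pictured with a small stand-alone sketch. It does not use the real Trainer: `Hooks`, `MyModel` and `toy_fit` are hypothetical stand-ins, and the call order inside `toy_fit` is an assumption based on the hook descriptions on this page.

```python
class Hooks:
    """No-op defaults, standing in for the hook base class (hypothetical)."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass

    def on_batch_end(self):
        pass


class MyModel(Hooks):
    """Overrides only the hooks it cares about; the rest stay no-ops."""

    def __init__(self):
        self.calls = []

    def on_epoch_start(self):
        self.calls.append("epoch_start")

    def on_batch_end(self):
        self.calls.append("batch_end")


def toy_fit(model, nb_epochs, nb_batches):
    """Toy loop that calls each hook at the point its description names."""
    for _ in range(nb_epochs):
        model.on_epoch_start()
        for _ in range(nb_batches):
            model.on_batch_start()
            # forward, backward and the optimizer step would run here
            model.on_batch_end()
        model.on_epoch_end()


model = MyModel()
toy_fit(model, nb_epochs=1, nb_batches=2)
```

Because the base class supplies a no-op for every hook, the loop can call all of them unconditionally and a model pays only for the ones it overrides.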
**Contributing** If there's a hook you'd like to add, simply:
1. Fork PyTorchLightning.
2. Add the hook [here](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py).
3. Add a call to the hook at the correct place in the [Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/models/trainer.py).
---
#### on_epoch_start
Called in the training loop at the very beginning of the epoch.
```python
def on_epoch_start(self):
    # do something when the epoch starts
    pass
```
---
#### on_epoch_end
Called in the training loop at the very end of the epoch.
```python
def on_epoch_end(self):
    # do something when the epoch ends
    pass
```
---
#### on_batch_start
Called in the training loop before anything happens for that batch.
```python
def on_batch_start(self):
    # do something when the batch starts
    pass
```
---
#### on_batch_end
Called in the training loop after the batch.
```python
def on_batch_end(self):
    # do something when the batch ends
    pass
```
---
#### on_pre_performance_check
Called at the very beginning of the validation loop.
```python
def on_pre_performance_check(self):
    # do something before validation starts
    pass
```
---
#### on_post_performance_check
Called at the very end of the validation loop.
```python
def on_post_performance_check(self):
    # do something when validation ends
    pass
```
---
#### on_tng_metrics
Called in the training loop, right before metrics are logged.
Although you can log at any time by using self.experiment, you can use
this callback to modify what will be logged.
```python
def on_tng_metrics(self, metrics):
    # inspect or modify the metrics before they are logged
    pass
```
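As an illustration of modifying what gets logged, here is a minimal stand-alone sketch. It assumes `metrics` is a plain dict that the trainer logs after the hook returns, so the changes are made in place. `MetricsFilter` and the key names `loss`, `loss_rounded` and `debug_grad_norm` are hypothetical.

```python
class MetricsFilter:
    """Stand-in for a LightningModule; only the hook is modelled (hypothetical)."""

    def on_tng_metrics(self, metrics):
        # drop a noisy entry and add a derived one, editing the dict in place
        metrics.pop("debug_grad_norm", None)
        if "loss" in metrics:
            metrics["loss_rounded"] = round(metrics["loss"], 3)


# simulate what the trainer would hand to the hook right before logging
metrics = {"loss": 0.123456, "debug_grad_norm": 7.5}
MetricsFilter().on_tng_metrics(metrics)
```

If the trainer instead used the hook's return value, the method would need to return the dict; check the Trainer source linked above before relying on in-place edits.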
---
#### optimizer_step
Calls `.step()` and `.zero_grad()` for each optimizer.
Override this method to change how the step is taken for each optimizer.
Called once per optimizer.
```python
# DEFAULT
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    optimizer.step()
    optimizer.zero_grad()

# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    # update generator opt every 2 steps
    if optimizer_i == 0:
        if batch_nb % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # update discriminator opt every 4 steps
    if optimizer_i == 1:
        if batch_nb % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # ...
    # add as many optimizers as you want
```
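To see which batches actually trigger an update under the alternating schedule, the branching can be pulled out into a pure function. This is only a sketch: `should_step` and `schedule` are hypothetical names, and no real optimizers are involved.

```python
def should_step(optimizer_i, batch_nb):
    """Mirror the branching of the alternating optimizer_step example."""
    if optimizer_i == 0:
        return batch_nb % 2 == 0  # generator: every 2nd batch
    if optimizer_i == 1:
        return batch_nb % 4 == 0  # discriminator: every 4th batch
    return False  # optimizers without a rule never step


# for each of the first 8 batches, list the optimizer indices that step
schedule = {
    batch_nb: [i for i in (0, 1) if should_step(i, batch_nb)]
    for batch_nb in range(8)
}
```

Note that the example skips `zero_grad()` together with `step()`, so gradients from skipped batches accumulate into the next update for that optimizer.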
This step allows you to do a lot of non-standard training tricks such as learning-rate warm-up:
```python
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    # warm up lr
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams.learning_rate

    # update params
    optimizer.step()
    optimizer.zero_grad()
```
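The warm-up arithmetic can be checked in isolation. The sketch below uses the same linear scale in a hypothetical helper, `warmup_lr`, with an assumed base learning rate of 0.02. The hook itself leaves the learning rate untouched once warm-up is over; the helper models that by returning `base_lr`.

```python
def warmup_lr(global_step, base_lr, warmup_steps=500):
    """Linearly ramp the learning rate over the first warmup_steps steps."""
    if global_step < warmup_steps:
        lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
        return lr_scale * base_lr
    # after warm-up the rate stays at its base value
    return base_lr


# learning rate at the first step, mid warm-up, the last warm-up step, and after
lrs = [warmup_lr(step, base_lr=0.02) for step in (0, 249, 499, 500)]
```

The `+ 1` means the first step already uses a non-zero rate (1/500 of the base) and the last warm-up step, 499, reaches the full base rate.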
---
#### on_before_zero_grad
Called in the training loop after taking an optimizer step and before zeroing grads.
A good place to inspect the weights right after they have been updated.
Called once per optimizer.
```python
def on_before_zero_grad(self, optimizer):
    # do something with the optimizer or inspect it.
    pass
```
---
#### on_after_backward
Called in the training loop after `model.backward()`.
This is the ideal place to inspect or log gradient information.
```python
def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        # state_dict() holds weights, not gradients, so read .grad from the parameters
        for name, param in self.named_parameters():
            if param.grad is None:  # skip parameters that received no gradient
                continue
            self.experiment.add_histogram(tag=name, values=param.grad, global_step=self.trainer.global_step)
```