# Hooks
[[Github Code](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py)]
There are cases when you might want to do something different at different parts of the training/validation loop.
To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
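
As a rough illustration of the mechanism, the sketch below uses a toy stand-in for the trainer loop. `ToyHooks`, `ToyTrainer`, and `MyModule` are hypothetical names invented for this example and are not part of the library; the real Trainer calls hooks at more points than shown here.

```python
class ToyHooks:
    """Hypothetical base class: every hook is a no-op by default."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass


class MyModule(ToyHooks):
    """Overrides only the hook it cares about."""

    def __init__(self):
        self.events = []

    def on_epoch_start(self):
        self.events.append("epoch started")


class ToyTrainer:
    """Hypothetical loop that calls the hooks at fixed points."""

    def fit(self, module, nb_epochs):
        for _ in range(nb_epochs):
            module.on_epoch_start()
            # ... training batches would run here ...
            module.on_epoch_end()


module = MyModule()
ToyTrainer().fit(module, nb_epochs=2)
print(module.events)  # ['epoch started', 'epoch started']
```

Because the base class supplies no-op defaults, a module only needs to define the hooks it actually uses.
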
**Contributing** If there's a hook you'd like to add, simply:
1. Fork PyTorchLightning.
2. Add the hook [here](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py).
3. Call the hook at the correct place in the [Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/models/trainer.py).
---
#### on_epoch_start
Called in the training loop at the very beginning of the epoch.
```python
def on_epoch_start(self):
    # do something when the epoch starts
    pass
```
---
#### on_epoch_end
Called in the training loop at the very end of the epoch.
```python
def on_epoch_end(self):
    # do something when the epoch ends
    pass
```
---
#### on_batch_start
Called in the training loop before anything happens for that batch.
```python
def on_batch_start(self):
    # do something when the batch starts
    pass
```
---
#### on_batch_end
Called in the training loop after the batch.
```python
def on_batch_end(self):
    # do something when the batch ends
    pass
```
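
One possible use of the `on_batch_start` / `on_batch_end` pair is timing each batch. The sketch below is only an illustration: `BatchTimerMixin` and its `batch_times` attribute are hypothetical names, and the hook signatures follow the ones shown above. In practice the mixin would be combined with a LightningModule; here the hooks are called by hand to keep the example self-contained.

```python
import time


class BatchTimerMixin:
    """Hypothetical mixin that records how long each training batch takes."""

    def on_batch_start(self):
        # remember when the batch began
        self._batch_t0 = time.perf_counter()

    def on_batch_end(self):
        # store the elapsed wall-clock time in seconds
        elapsed = time.perf_counter() - self._batch_t0
        if not hasattr(self, "batch_times"):
            self.batch_times = []
        self.batch_times.append(elapsed)


# stand-alone demonstration: call the hooks by hand, as a trainer would
timer = BatchTimerMixin()
for _ in range(3):
    timer.on_batch_start()
    time.sleep(0.001)  # stands in for the forward/backward pass
    timer.on_batch_end()

print(len(timer.batch_times))
```
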
---
#### on_pre_performance_check
Called at the very beginning of the validation loop.
```python
def on_pre_performance_check(self):
    # do something before validation starts
    pass
```
---
#### on_post_performance_check
Called at the very end of the validation loop.
```python
def on_post_performance_check(self):
    # do something after validation ends
    pass
```
---
#### optimizer_step
Calls `.step()` and `.zero_grad()` for each optimizer.
You can override this method to adjust how the optimizer step is done for each optimizer.
Called once per optimizer.
```python
# DEFAULT
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
    optimizer.step()
    optimizer.zero_grad()


# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
    # update generator opt every 2 steps
    if optimizer_i == 0:
        if batch_nb % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # update discriminator opt every 4 steps
    if optimizer_i == 1:
        if batch_nb % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # ...
    # add as many optimizers as you want
```
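
To see what the alternating schedule does in practice, the following sketch replays the same modulo conditions without any real optimizers. `should_step` is a hypothetical helper written for this illustration, and letting any further optimizer step on every batch is an assumption of the sketch, not something the example specifies.

```python
def should_step(optimizer_i, batch_nb):
    """Hypothetical helper mirroring the modulo checks in the GAN example."""
    if optimizer_i == 0:
        return batch_nb % 2 == 0  # generator: every 2 batches
    if optimizer_i == 1:
        return batch_nb % 4 == 0  # discriminator: every 4 batches
    return True  # assumption for this sketch: other optimizers step every batch


generator_steps = [b for b in range(8) if should_step(0, b)]
discriminator_steps = [b for b in range(8) if should_step(1, b)]

print(generator_steps)      # [0, 2, 4, 6]
print(discriminator_steps)  # [0, 4]
```

Since the example only calls `zero_grad()` alongside `step()`, gradients from the skipped batches would presumably keep accumulating until the next step, as is usual in PyTorch.
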
This step allows you to do a lot of non-standard training tricks such as learning-rate warm-up:
```python
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
    # warm up lr
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams.learning_rate

    # update params
    optimizer.step()
    optimizer.zero_grad()
```
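
The scaling rule in the warm-up example can be checked in isolation. The sketch below factors it into a hypothetical helper, `warmup_lr`, using the 500-step horizon from the example as a default; the base learning rate of 0.02 is an arbitrary value chosen for illustration.

```python
def warmup_lr(global_step, base_lr, warmup_steps=500):
    """Linearly ramp the learning rate up to base_lr over warmup_steps."""
    lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
    return lr_scale * base_lr


# learning rate at the start, middle, and end of warm-up, and afterwards
for step in (0, 249, 499, 1000):
    print(step, warmup_lr(step, base_lr=0.02))
```

The scale reaches 1.0 at step 499, so the learning rate equals the base value from then on.
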
---
#### on_before_zero_grad
Called in the training loop after taking an optimizer step and before zeroing grads.
This is a good place to inspect the weights right after they have been updated.
Called once per optimizer.
```python
def on_before_zero_grad(self, optimizer):
    # do something with the optimizer or inspect it.
    pass
```
---
#### on_after_backward
Called in the training loop after `model.backward()`.
This is the ideal place to inspect or log gradient information.
```python
def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for name, param in self.named_parameters():
            if param.grad is None:
                continue  # skip parameters that received no gradient
            self.logger.experiment.add_histogram(
                tag=name, values=param.grad, global_step=self.trainer.global_step
            )
```