4.0 KiB
Hooks
There are cases when you might want to do something different at different parts of the training/validation loop. To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
Contributing If there's a hook you'd like to add, simply:
- Fork PyTorchLightning.
- Add the hook here.
- Add the correct place in the Trainer where it should be called.
on_epoch_start
Called in the training loop at the very beginning of the epoch.
def on_epoch_start(self):
# do something when the epoch starts
on_epoch_end
Called in the training loop at the very end of the epoch.
def on_epoch_end(self):
# do something when the epoch ends
on_batch_start
Called in the training loop before anything happens for that batch.
def on_batch_start(self):
# do something when the batch starts
on_batch_end
Called in the training loop after the batch.
def on_batch_end(self):
# do something when the batch ends
on_pre_performance_check
Called at the very beginning of the validation loop.
def on_pre_performance_check(self):
# do something before validation starts
on_post_performance_check
Called at the very end of the validation loop.
def on_post_performance_check(self):
# do something before validation end
optimizer_step
Calls .step() and .zero_grad for each optimizer.
You can override this method to adjust how you do the optimizer step for each optimizer
Called once per optimizer
# DEFAULT
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
optimizer.step()
optimizer.zero_grad()
# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
# update generator opt every 2 steps
if optimizer_i == 0:
if batch_nb % 2 == 0 :
optimizer.step()
optimizer.zero_grad()
# update discriminator opt every 4 steps
if optimizer_i == 1:
if batch_nb % 4 == 0 :
optimizer.step()
optimizer.zero_grad()
# ...
# add as many optimizers as you want
This step allows you to do a lot of non-standard training tricks such as learning-rate warm-up:
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * self.hparams.learning_rate
# update params
optimizer.step()
optimizer.zero_grad()
on_before_zero_grad
Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.
Called once per optimizer
def on_before_zero_grad(self, optimizer):
# do something with the optimizer or inspect it.
on_after_backward
Called in the training loop after model.backward() This is the ideal place to inspect or log gradient information
def on_after_backward(self):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
params = self.state_dict()
for k, v in params.items():
grads = v
name = k
self.logger.experiment.add_histogram(tag=name, values=grads, global_step=self.trainer.global_step)