Hooks

[Github Code]

Sometimes you need to run custom code at specific points in the training or validation loop. To enable a hook, override the corresponding method in your LightningModule; the trainer will call it at the correct time.
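As a rough illustration of the override mechanism, the sketch below uses a toy stand-in for the module and the training loop. `ToyModule`, `toy_fit`, and the `calls` list are hypothetical names for this example only, not part of Lightning, and the real Trainer does far more between the calls. The idea it shows: the base class defines no-op hooks, the loop calls them at fixed points, and overriding a method is enough to run your code there.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule: every hook is a no-op by default."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass


def toy_fit(module, n_epochs, n_batches):
    """Hypothetical stand-in for the trainer's training loop."""
    for _ in range(n_epochs):
        module.on_epoch_start()
        for _ in range(n_batches):
            module.on_batch_start()
            # ... forward, backward and the optimizer step would happen here ...
        module.on_epoch_end()


class MyModule(ToyModule):
    def __init__(self):
        self.calls = []  # record which hooks fired, in order

    # only the hooks we care about are overridden; on_epoch_end stays a no-op
    def on_epoch_start(self):
        self.calls.append("epoch_start")

    def on_batch_start(self):
        self.calls.append("batch_start")


model = MyModule()
toy_fit(model, n_epochs=1, n_batches=2)
print(model.calls)
```

Hooks that are not overridden fall back to the no-op in the base class, so a module only pays for the hooks it actually uses.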

Contributing: If there's a hook you'd like to add:

  1. Fork PyTorchLightning.
  2. Add the hook here.
  3. Add the correct place in the Trainer where it should be called.

on_epoch_start

Called in the training loop at the very beginning of the epoch.

def on_epoch_start(self):
    # do something when the epoch starts
    pass

on_epoch_end

Called in the training loop at the very end of the epoch.

def on_epoch_end(self):
    # do something when the epoch ends
    pass

on_batch_start

Called in the training loop before anything happens for that batch.

def on_batch_start(self):
    # do something when the batch starts
    pass

on_pre_performance_check

Called at the very beginning of the validation loop.

def on_pre_performance_check(self):
    # do something before validation starts
    pass

on_post_performance_check

Called at the very end of the validation loop.

def on_post_performance_check(self):
    # do something after validation ends
    pass
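One possible use of this pair of hooks is timing the validation loop. The sketch below is a minimal, framework-free illustration: `ValidationTimer`, `_val_start`, and `last_val_seconds` are hypothetical names, and the two hook calls at the bottom only simulate what the trainer would do around a validation run.

```python
import time


class ValidationTimer:
    """Hypothetical stand-in: in practice these two methods would live in your LightningModule."""

    def on_pre_performance_check(self):
        # remember when the validation loop started
        self._val_start = time.perf_counter()

    def on_post_performance_check(self):
        # elapsed wall-clock time for the whole validation loop
        self.last_val_seconds = time.perf_counter() - self._val_start


# simulate what the trainer would do around a validation run
timer = ValidationTimer()
timer.on_pre_performance_check()
time.sleep(0.01)  # placeholder for the validation loop itself
timer.on_post_performance_check()
print(f"validation took {timer.last_val_seconds:.3f}s")
```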

on_tng_metrics

Called in the training loop, right before metrics are logged. Although you can log at any time by using self.experiment, you can use this callback to modify what will be logged.

def on_tng_metrics(self, metrics):
    # modify the metrics before they are logged
    pass
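As a hedged example of modifying what gets logged, the sketch below assumes that `metrics` is a mutable dict and that the trainer logs it after the hook returns, so in-place changes take effect. The text above does not confirm either detail, so check the Trainer source before relying on it. `MetricsFilter` and the keys `loss` and `debug_value` are hypothetical.

```python
class MetricsFilter:
    """Hypothetical stand-in: in practice on_tng_metrics would be defined on your LightningModule."""

    def on_tng_metrics(self, metrics):
        # assumption: `metrics` is a mutable dict that the trainer logs after this hook returns
        if "loss" in metrics:
            metrics["loss"] = round(metrics["loss"], 4)  # trim noise in the logged value
        metrics.pop("debug_value", None)  # drop an entry we don't want logged


# simulate the trainer calling the hook right before logging
metrics = {"loss": 0.123456789, "debug_value": 42}
MetricsFilter().on_tng_metrics(metrics)
print(metrics)
```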

on_before_zero_grad

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect the weights right after they have been updated.

Called once per optimizer.

def on_before_zero_grad(self, optimizer):
    # do something with the optimizer or inspect it
    pass
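The ordering described above can be sketched with a toy loop. Everything here (`ToyOptimizer`, `InspectingModule`, `toy_optimizer_phase`) is a hypothetical stand-in written for this example, not Lightning or PyTorch code. It only illustrates that when the hook fires the weights already reflect the step, the gradients have not been cleared yet, and the hook runs once for each optimizer.

```python
class ToyOptimizer:
    """Hypothetical optimizer holding one scalar weight and its gradient."""

    def __init__(self, weight, lr):
        self.weight = weight
        self.grad = 0.0
        self.lr = lr

    def step(self):
        self.weight -= self.lr * self.grad  # plain gradient descent update

    def zero_grad(self):
        self.grad = 0.0


class InspectingModule:
    def __init__(self):
        self.seen = []

    def on_before_zero_grad(self, optimizer):
        # weights are already updated here, and the gradient has not been cleared yet
        self.seen.append((optimizer.weight, optimizer.grad))


def toy_optimizer_phase(module, optimizers):
    """Hypothetical stand-in for the part of the training loop described above."""
    for opt in optimizers:
        opt.step()
        module.on_before_zero_grad(opt)  # called once per optimizer
        opt.zero_grad()


opt_a = ToyOptimizer(weight=1.0, lr=0.5)
opt_b = ToyOptimizer(weight=2.0, lr=0.5)
opt_a.grad, opt_b.grad = 1.0, 2.0  # pretend backward() produced these gradients

module = InspectingModule()
toy_optimizer_phase(module, [opt_a, opt_b])
print(module.seen)
```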

on_after_backward

Called in the training loop after model.backward(). This is the ideal place to inspect or log gradient information.

def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        # state_dict() holds the weights, not the gradients, so read .grad from each parameter
        for name, param in self.named_parameters():
            if param.grad is None:  # skip parameters that received no gradient
                continue
            self.experiment.add_histogram(tag=name, values=param.grad, global_step=self.trainer.global_step)