diff --git a/docs/Trainer/hooks.md b/docs/Trainer/hooks.md index f785d909c1..9eb1dbf872 100644 --- a/docs/Trainer/hooks.md +++ b/docs/Trainer/hooks.md @@ -67,6 +67,33 @@ def on_tng_metrics(self, metrics): # do something before validation end ``` +--- +#### optimizer_step +Calls .step() and .zero_grad for each optimizer. +You can override this method to adjust how you do the optimizer step for each optimizer + +Called once per optimizer +```python +# DEFAULT +def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i): + optimizer.step() + optimizer.zero_grad() + +# Alternating schedule for optimizer steps (ie: GANs) +def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i): + # update generator opt every 2 steps + if optimizer_i == 0: + if batch_nb % 2 == 0 : + optimizer.step() + optimizer.zero_grad() + + # update discriminator opt every 4 steps + if optimizer_i == 1: + if batch_nb % 4 == 0 : + optimizer.step() + optimizer.zero_grad() +``` + --- #### on_before_zero_grad Called in the training loop after taking an optimizer step and before zeroing grads.