From 53ec3bc5bcda61bf4f319d14402c5322e2c25aec Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 13 Aug 2019 11:47:35 -0400
Subject: [PATCH] updated optimizer_step docs

---
 docs/Trainer/hooks.md | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/docs/Trainer/hooks.md b/docs/Trainer/hooks.md
index f785d909c1..9eb1dbf872 100644
--- a/docs/Trainer/hooks.md
+++ b/docs/Trainer/hooks.md
@@ -67,6 +67,33 @@ def on_tng_metrics(self, metrics):
     # do something before validation end
 ```
 
+---
+#### optimizer_step
+Calls .step() and .zero_grad() for each optimizer.
+You can override this method to adjust how you do the optimizer step for each optimizer.
+
+Called once per optimizer.
+```python
+# DEFAULT
+def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
+    optimizer.step()
+    optimizer.zero_grad()
+
+# Alternating schedule for optimizer steps (ie: GANs)
+def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
+    # update generator opt every 2 steps
+    if optimizer_i == 0:
+        if batch_nb % 2 == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+    # update discriminator opt every 4 steps
+    if optimizer_i == 1:
+        if batch_nb % 4 == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+```
+
 ---
 #### on_before_zero_grad
 Called in the training loop after taking an optimizer step and before zeroing grads.