updated optimizer_step docs
parent acc16565c5
commit 53ec3bc5bc
@@ -67,6 +67,33 @@ def on_tng_metrics(self, metrics):
    # do something before validation end
```

---
#### optimizer_step

Calls `.step()` and `.zero_grad()` for each optimizer.

You can override this method to adjust how you do the optimizer step for each optimizer.

Called once per optimizer.
```python
# DEFAULT
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    optimizer.step()
    optimizer.zero_grad()

# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    # update generator opt every 2 steps
    if optimizer_i == 0:
        if batch_nb % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # update discriminator opt every 4 steps
    if optimizer_i == 1:
        if batch_nb % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()
```
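You can also override `optimizer_step` for other step-level behaviour, for example learning rate warm-up. A minimal sketch, assuming a linear warm-up over the first 500 batches of the first epoch and a base learning rate of 0.001 (both values are illustrative, not part of the default interface):

```python
# Learning rate warm-up (illustrative sketch)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i):
    # linearly ramp up the learning rate during the assumed 500-batch warm-up
    if current_epoch == 0 and batch_nb < 500:
        lr_scale = float(batch_nb + 1) / 500.0
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * 0.001  # 0.001 is an assumed base learning rate

    # the override is still responsible for stepping and zeroing the optimizer
    optimizer.step()
    optimizer.zero_grad()
```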
---
#### on_before_zero_grad
Called in the training loop after taking an optimizer step and before zeroing grads.
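A minimal sketch of an override, assuming the hook receives the optimizer whose gradients are about to be zeroed:

```python
def on_before_zero_grad(self, optimizer):
    # inspect gradients before they are cleared, e.g. compute their global L2 norm
    grad_norm = sum(p.grad.norm(2) ** 2
                    for p in self.parameters() if p.grad is not None) ** 0.5
    print('grad norm before zero_grad:', float(grad_norm))
```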