Fix GAN training. (#603)

* fix dangling gradients

make sure only the gradients of the current optimizer's parameters are calculated in the training step.

* add a note about gradient updates with multiple optimizers

* Update training_loop.py
Ayberk Aydın 2020-01-14 06:12:04 +03:00 committed by William Falcon
parent 1969c6cc2a
commit 0ae3dd9ed4
2 changed files with 9 additions and 0 deletions


@@ -664,6 +664,8 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
.. note:: If you use multiple optimizers, training_step will have an additional `optimizer_idx` parameter.
.. note:: If you use LBFGS lightning handles the closure function automatically for you.
.. note:: If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
Example
-------
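
For illustration, a minimal sketch of a two-optimizer, GAN-style module that uses the `optimizer_idx` argument described in the notes above. This is not the docstring's own (elided) example: the class name, the nn.Linear stand-ins for the generator and discriminator, and the loss terms are assumptions made for this sketch.

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class GAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder networks; a real GAN would use proper architectures
        self.generator = nn.Linear(16, 32)
        self.discriminator = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real = batch[0]  # assumes the dataloader yields real samples of size 32 first
        ones = torch.ones(real.size(0), 1)
        zeros = torch.zeros(real.size(0), 1)
        if optimizer_idx == 0:
            # generator step: with this fix, only generator parameters receive gradients
            fake = self.generator(torch.randn(real.size(0), 16))
            g_loss = F.binary_cross_entropy_with_logits(self.discriminator(fake), ones)
            return {'loss': g_loss}
        # discriminator step: detach so the generator graph is not backpropagated through
        fake = self.generator(torch.randn(real.size(0), 16)).detach()
        d_loss = 0.5 * (F.binary_cross_entropy_with_logits(self.discriminator(real), ones)
                        + F.binary_cross_entropy_with_logits(self.discriminator(fake), zeros))
        return {'loss': d_loss}

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]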


@@ -454,6 +454,13 @@ class TrainerTrainLoopMixin(ABC):
# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in a multiple-optimizer setup.
for param in self.get_model().parameters():
param.requires_grad = False
for group in optimizer.param_groups:
for param in group['params']:
param.requires_grad = True
# wrap the forward step in a closure so second order methods work
def optimizer_closure():
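
For reference, a standalone sketch (plain PyTorch, outside Lightning; the toy parameters and optimizer names are made up for illustration) of what the freeze/unfreeze pattern above achieves: backward() only populates .grad on the active optimizer's parameters, so no dangling gradients are left on the other optimizer's parameters.

import torch

gen_param = torch.nn.Parameter(torch.randn(3))
disc_param = torch.nn.Parameter(torch.randn(3))
opt_gen = torch.optim.SGD([gen_param], lr=0.1)
opt_disc = torch.optim.SGD([disc_param], lr=0.1)

# same pattern as the diff above: freeze everything, then unfreeze only the
# parameters owned by the optimizer currently being stepped (here: opt_gen)
for p in (gen_param, disc_param):
    p.requires_grad = False
for group in opt_gen.param_groups:
    for p in group['params']:
        p.requires_grad = True

# a loss that touches both parameter sets
loss = (gen_param * disc_param).sum()
loss.backward()

print(gen_param.grad)   # populated: belongs to the active optimizer
print(disc_param.grad)  # None: no dangling gradient for the inactive optimizer

Because the training loop re-runs this freeze/unfreeze pass for every optimizer on every batch, the flags are flipped back automatically when the next optimizer's turn comes.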