diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index ae38b1ef6f..f1884a6b7b 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -22,7 +22,7 @@ For advanced research topics like reinforcement learning, sparse coding, or GAN to manually manage the optimization process. To do so, do the following: * Ignore the optimizer_idx argument -* So we can scale the loss automatically for you use self.backward(loss) instead of loss.backward() +* So we can scale the loss automatically for you use self.manual_backward(loss) instead of loss.backward() .. code-block:: python @@ -34,7 +34,7 @@ to manually manage the optimization process. To do so, do the following: loss_a = ... # use self.backward which will also handle scaling the loss when using amp - self.backward(loss_a, opt_g) + self.manual_backward(loss_a, opt_g) opt_g.step() opt_g.zero_grad() @@ -42,8 +42,8 @@ to manually manage the optimization process. To do so, do the following: loss_b = ... # pass in any args that loss.backward() normally takes - self.backward(loss_b, opt_d, retain_graph=True) - self.backward(loss_b, opt_d, retain_graph=True) + self.manual_backward(loss_b, opt_d, retain_graph=True) + self.manual_backward(loss_b, opt_d, retain_graph=True) loss_b.step() loss_b.zero_grad() diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e90512f071..dd4f5787d4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1038,7 +1038,7 @@ class LightningModule( "`configure_optimizers` must be implemented to be used with the Lightning Trainer" ) - def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None: + def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None: """ Call this directly from your training_step when doing optimizations manually. By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you @@ -1051,10 +1051,31 @@ class LightningModule( (opt_a, opt_b) = self.optimizers() loss = ... # automatically applies scaling, etc... - self.backward(loss, opt_a) + self.manual_backward(loss, opt_a) """ self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs) + def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None: + """ + Override backward with your own implementation if you need to. + + Args: + loss: Loss is already scaled by accumulated grads + optimizer: Current optimizer being used + optimizer_idx: Index of the current optimizer being used + + Called to perform backward step. + Feel free to override as needed. + The loss passed in has already been scaled for accumulated gradients if requested. + + Example:: + + def backward(self, trainer, loss, optimizer, optimizer_idx): + loss.backward() + + """ + loss.backward() + def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): """ Makes sure only the gradients of the current optimizer's parameters are calculated diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 657bd31534..9ef6082478 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -296,7 +296,7 @@ Example:: (opt) = self.optimizers() loss = ... - self.backward(loss, opt) + self.manual_backward(loss, opt) opt.step() opt.zero_grad() @@ -311,12 +311,12 @@ Example:: (opt_a, opt_b) = self.optimizers() gen_loss = ... - self.backward(gen_loss, opt_a) + self.manual_backward(gen_loss, opt_a) opt_a.step() opt_a.zero_grad() disc_loss = ... - self.backward(disc_loss, opt_b) + self.manual_backward(disc_loss, opt_b) opt_b.step() opt_b.zero_grad() diff --git a/tests/trainer/dynamic_args/test_multiple_optimizers.py b/tests/trainer/dynamic_args/test_multiple_optimizers.py index 6d6b9d5cde..23090a7d9d 100644 --- a/tests/trainer/dynamic_args/test_multiple_optimizers.py +++ b/tests/trainer/dynamic_args/test_multiple_optimizers.py @@ -65,13 +65,13 @@ def test_multiple_optimizers_manual(tmpdir): loss_1 = self.step(batch[0]) # fake generator - self.backward(loss_1, opt_a) + self.manual_backward(loss_1, opt_a) opt_a.step() opt_a.zero_grad() # fake discriminator loss_2 = self.step(batch[0]) - self.backward(loss_2, opt_b) + self.manual_backward(loss_2, opt_b) opt_b.step() opt_b.zero_grad() diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 13e3d23d5c..9dc7c1caff 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -23,7 +23,7 @@ def test_multiple_optimizers_manual(tmpdir): if batch_idx > 0: assert torch.all(self.layer.weight.grad == 0) - self.backward(loss_1, opt_a) + self.manual_backward(loss_1, opt_a) opt_a.step() opt_a.zero_grad() assert torch.all(self.layer.weight.grad == 0) @@ -33,8 +33,8 @@ def test_multiple_optimizers_manual(tmpdir): # ensure we forward the correct params to the optimizer # without retain_graph we can't do multiple backward passes - self.backward(loss_2, opt_b, retain_graph=True) - self.backward(loss_2, opt_a, retain_graph=True) + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) assert self.layer.weight.grad is not None opt_b.step() @@ -87,7 +87,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): if batch_idx > 0: assert torch.all(self.layer.weight.grad == 0) - self.backward(loss_1, opt_a) + self.manual_backward(loss_1, opt_a) opt_a.step() opt_a.zero_grad() assert torch.all(self.layer.weight.grad == 0) @@ -97,8 +97,8 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): # ensure we forward the correct params to the optimizer # without retain_graph we can't do multiple backward passes - self.backward(loss_2, opt_b, retain_graph=True) - self.backward(loss_2, opt_a, retain_graph=True) + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) assert self.layer.weight.grad is not None opt_b.step() @@ -157,7 +157,7 @@ def test_multiple_optimizers_manual_apex(tmpdir): if batch_idx > 0: assert torch.all(self.layer.weight.grad == 0) - self.backward(loss_1, opt_a) + self.manual_backward(loss_1, opt_a) opt_a.step() opt_a.zero_grad() assert torch.all(self.layer.weight.grad == 0) @@ -168,8 +168,8 @@ def test_multiple_optimizers_manual_apex(tmpdir): # ensure we forward the correct params to the optimizer # without retain_graph we can't do multiple backward passes - self.backward(loss_2, opt_b, retain_graph=True) - self.backward(loss_2, opt_a, retain_graph=True) + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) assert self.layer.weight.grad is not None opt_b.step()