diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1793269db..8e7cc82ac1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
 
+- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))
+
+
 ## [1.3.8] - 2021-07-01
 
 ### Fixed
 
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index cf892219e3..9b803a2790 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -204,20 +204,17 @@ class TrainingBatchLoop(Loop):
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
-                if len(self.trainer.optimizers) > 1:
-                    # revert back to previous state
-                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
-        if not result:
-            # user decided to skip optimization
-            return result
-
-        # update running loss + reset accumulated loss
-        self._update_running_loss(result.loss)
+        if result:
+            # if no result, user decided to skip optimization
+            # otherwise update running loss + reset accumulated loss
+            self._update_running_loss(result.loss)
+            self._process_closure_result(result)
 
-        self._process_closure_result(result)
+        # untoggle model params
+        self._run_optimization_end(opt_idx)
         return result
 
     def _training_step_and_backward_closure(
@@ -509,6 +506,11 @@ class TrainingBatchLoop(Loop):
         model = self.trainer.lightning_module
         model.toggle_optimizer(optimizer, opt_idx)
 
+    def _run_optimization_end(self, opt_idx: int) -> None:
+        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
+            model = self.trainer.lightning_module
+            model.untoggle_optimizer(opt_idx)
+
     @contextmanager
     def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
         """
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index f05305c785..09c7b3a711 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -197,7 +197,7 @@ def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
         max_epochs=1,
         default_root_dir=tmpdir,
         limit_train_batches=8,
-        accumulate_grad_batches=1,
+        accumulate_grad_batches=2,
         limit_val_batches=0,
     )
     trainer.fit(model)
@@ -331,7 +331,7 @@ def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):
         max_epochs=1,
         default_root_dir=tmpdir,
         limit_train_batches=8,
-        accumulate_grad_batches=1,
+        accumulate_grad_batches=2,
     )
     trainer.fit(model)
 
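
Not part of the patch: a minimal sketch of the configuration this fix targets, i.e. automatic optimization with two optimizers and accumulate_grad_batches > 1. It assumes the Lightning 1.3/1.4-era multiple-optimizer API (training_step receiving optimizer_idx); the TwoOptimizerModel class and the toy data are hypothetical and simply mirror the updated tests above.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    """Hypothetical toy module where each optimizer owns one layer."""

    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(32, 32)
        self.layer_2 = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # While an optimizer is toggled, only its own parameters require grad;
        # the missing untoggle on accumulation steps is what #8284 addresses.
        (x,) = batch
        return self.layer_2(self.layer_1(x)).sum()

    def configure_optimizers(self):
        opt_1 = torch.optim.SGD(self.layer_1.parameters(), lr=0.1)
        opt_2 = torch.optim.SGD(self.layer_2.parameters(), lr=0.1)
        return opt_1, opt_2


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=8,
        accumulate_grad_batches=2,  # accumulation steps are where the untoggle was skipped
        limit_val_batches=0,
    )
    trainer.fit(TwoOptimizerModel(), train_data)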