diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44914061b8..9179711b5c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
 - Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 84d69765c7..a555146875 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -724,7 +724,6 @@ class TrainLoop:
             # -------------------
             # calculate loss (train step + train step end)
             # -------------------
-
             # automatic_optimization=True: perform ddp sync only when performing optimizer_step
             # automatic_optimization=False: don't block synchronization here
             with self.block_ddp_sync_behaviour():
@@ -737,6 +736,9 @@ class TrainLoop:
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -837,10 +839,6 @@ class TrainLoop:
                         "training_step returned None. If this was on purpose, ignore this warning..."
                     )
 
-                if len(self.trainer.optimizers) > 1:
-                    # revert back to previous state
-                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index 24b32c8725..aba3b53248 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def test_multiple_optimizers_no_opt_idx_argument(tmpdir):
 
     with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
         trainer.fit(TestModel())
+
+
+def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
+    """
+    This test ensures custom optimizer_step works,
+    even when optimizer.step is not called for a particular optimizer.
+    """
+
+    class TestModel(BoringModel):
+        training_step_called = [0, 0]
+        optimizer_step_called = [0, 0]
+
+        def __init__(self):
+            super().__init__()
+            self.layer_a = torch.nn.Linear(32, 2)
+            self.layer_b = torch.nn.Linear(32, 2)
+
+        def configure_optimizers(self):
+            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
+            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
+            return opt_a, opt_b
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.training_step_called[optimizer_idx] += 1
+            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
+            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def optimizer_step(
+            self,
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_idx,
+            optimizer_closure,
+            **_,
+        ):
+            # update first optimizer every step
+            if optimizer_idx == 0:
+                self.optimizer_step_called[optimizer_idx] += 1
+                optimizer.step(closure=optimizer_closure)
+
+            # update second optimizer every 2 steps
+            if optimizer_idx == 1:
+                if batch_idx % 2 == 0:
+                    self.optimizer_step_called[optimizer_idx] += 1
+                    optimizer.step(closure=optimizer_closure)
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert model.training_step_called == [4, 2]
+    assert model.optimizer_step_called == [4, 2]