diff --git a/CHANGELOG.md b/CHANGELOG.md index 221f4ff8ee..a280660229 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,10 +61,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587)) + - Delete the deprecated `TrainerLoggingMixin` class ([#8609](https://github.com/PyTorchLightning/pytorch-lightning/pull/8609)) -- + +- Removed the deprecated `optimizer_idx` from `training_step` as an accepted argument in manual optimization ([#8576](https://github.com/PyTorchLightning/pytorch-lightning/pull/8576)) + - diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 8a129bbacf..152d34cbb0 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -160,19 +160,20 @@ class TrainingBatchLoop(Loop): return len(self.get_active_optimizers(batch_idx)) def _run_optimization( - self, batch_idx: int, split_batch: Any, opt_idx: int = 0, optimizer: Optional[torch.optim.Optimizer] = None + self, + batch_idx: int, + split_batch: Any, + opt_idx: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None, ): """Runs closure (train step + backward) together with optimization if necessary. Args: batch_idx: the index of the current batch split_batch: the current tbptt split of the whole batch - opt_idx: the index of the current optimizer - optimizer: the current optimizer + opt_idx: the index of the current optimizer or `None` in case of manual optimization + optimizer: the current optimizer or `None` in case of manual optimization """ - # TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change - # opt_idx=0 to opt_idx=None in the signature here - # toggle model params self._run_optimization_start(opt_idx, optimizer) @@ -625,10 +626,10 @@ class TrainingBatchLoop(Loop): has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: - self._warning_cache.deprecation( - "`training_step` hook signature has changed in v1.3." - " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - " the old signature will be removed in v1.5" + raise ValueError( + "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but" + " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx" + " argument or set `self.automatic_optimization = True`." ) step_kwargs["optimizer_idx"] = opt_idx elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization: diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index cce9176c97..8dbe17e7a0 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -166,26 +166,6 @@ def test_v1_5_0_running_sanity_check(): assert not trainer.running_sanity_check -def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir): - class OldSignatureModel(BoringModel): - def __init__(self): - super().__init__() - self.automatic_optimization = False - - def training_step(self, batch, batch_idx, optimizer_idx): - assert optimizer_idx == 0 - return super().training_step(batch, batch_idx) - - def configure_optimizers(self): - return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)] - - model = OldSignatureModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) - - with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): - trainer.fit(model) - - def test_v1_5_0_model_checkpoint_period(tmpdir): with no_warning_call(DeprecationWarning): ModelCheckpoint(dirpath=tmpdir) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index fa1d2cd381..9ce3358fbe 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -1098,3 +1098,17 @@ def test_multiple_optimizers_logging(precision, tmpdir): assert set(trainer.logged_metrics) == {"epoch", "loss_d", "loss_g"} assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"} + + +def test_manual_optimization_training_step_signature(tmpdir): + """Test that Lightning raises an exception if the training_step signature has an optimier_idx by mistake.""" + + class ConfusedAutomaticManualModel(ManualOptModel): + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + model = ConfusedAutomaticManualModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) + + with pytest.raises(ValueError, match="Your `LightningModule.training_step` signature contains an `optimizer_idx`"): + trainer.fit(model)