remove support for optimizer_idx in the training_step for manual optimization (#8576)

Adrian Wälchli 2021-07-29 10:30:45 +02:00 committed by GitHub
parent 9c80727b8c
commit 7901d297d3
4 changed files with 29 additions and 31 deletions


@@ -61,10 +61,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
- Delete the deprecated `TrainerLoggingMixin` class ([#8609](https://github.com/PyTorchLightning/pytorch-lightning/pull/8609))
-
- Removed the deprecated `optimizer_idx` from `training_step` as an accepted argument in manual optimization ([#8576](https://github.com/PyTorchLightning/pytorch-lightning/pull/8576))
-
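In practice, this means a `LightningModule` that opts into manual optimization may no longer declare `optimizer_idx` in `training_step`; the optimizers are fetched with `self.optimizers()` instead. A minimal sketch of the new-style usage (model, shapes, and optimizer choices below are illustrative, not taken from this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TwoOptimizerModel(LightningModule):
    """Hypothetical example module with two optimizers and manual optimization."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.head_a = torch.nn.Linear(32, 1)
        self.head_b = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):  # no `optimizer_idx` argument anymore
        (x,) = batch
        opt_a, opt_b = self.optimizers()  # fetch the optimizers yourself

        loss_a = self.head_a(x).pow(2).mean()
        opt_a.zero_grad()
        self.manual_backward(loss_a)
        opt_a.step()

        loss_b = self.head_b(x).pow(2).mean()
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

    def configure_optimizers(self):
        return [
            torch.optim.SGD(self.head_a.parameters(), lr=1e-2),
            torch.optim.SGD(self.head_b.parameters(), lr=1e-2),
        ]


# Minimal usage: a random dataset just to exercise the loop.
dataset = TensorDataset(torch.randn(64, 32))
Trainer(fast_dev_run=2).fit(TwoOptimizerModel(), DataLoader(dataset, batch_size=8))
```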


@@ -160,19 +160,20 @@ class TrainingBatchLoop(Loop):
         return len(self.get_active_optimizers(batch_idx))

     def _run_optimization(
-        self, batch_idx: int, split_batch: Any, opt_idx: int = 0, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        batch_idx: int,
+        split_batch: Any,
+        opt_idx: Optional[int] = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
         """Runs closure (train step + backward) together with optimization if necessary.

         Args:
             batch_idx: the index of the current batch
             split_batch: the current tbptt split of the whole batch
-            opt_idx: the index of the current optimizer
-            optimizer: the current optimizer
+            opt_idx: the index of the current optimizer or `None` in case of manual optimization
+            optimizer: the current optimizer or `None` in case of manual optimization
         """
-        # TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
-        # opt_idx=0 to opt_idx=None in the signature here
         # toggle model params
         self._run_optimization_start(opt_idx, optimizer)
@@ -625,10 +626,10 @@ class TrainingBatchLoop(Loop):
         has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
         if has_opt_idx_in_train_step:
             if not lightning_module.automatic_optimization:
-                self._warning_cache.deprecation(
-                    "`training_step` hook signature has changed in v1.3."
-                    " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
-                    " the old signature will be removed in v1.5"
+                raise ValueError(
+                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
+                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
+                    " argument or set `self.automatic_optimization = True`."
                 )
             step_kwargs["optimizer_idx"] = opt_idx
         elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
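The branch above keys off `is_param_in_hook_signature`, Lightning's helper for inspecting the user's `training_step`. A rough sketch of that kind of check using only the standard library (not the actual implementation, which also accounts for `*args`/`**kwargs` and positional parameters):

```python
import inspect
from typing import Callable


def param_in_signature(fn: Callable, name: str) -> bool:
    # True if `fn` explicitly declares a parameter called `name`.
    return name in inspect.signature(fn).parameters


def training_step(self, batch, batch_idx, optimizer_idx):  # old-style signature
    ...


# The offending `optimizer_idx` parameter is detected, so the loop can raise.
assert param_in_signature(training_step, "optimizer_idx")
assert not param_in_signature(training_step, "dataloader_idx")
```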


@@ -166,26 +166,6 @@ def test_v1_5_0_running_sanity_check():
     assert not trainer.running_sanity_check


-def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
-    class OldSignatureModel(BoringModel):
-        def __init__(self):
-            super().__init__()
-            self.automatic_optimization = False
-
-        def training_step(self, batch, batch_idx, optimizer_idx):
-            assert optimizer_idx == 0
-            return super().training_step(batch, batch_idx)
-
-        def configure_optimizers(self):
-            return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
-
-    model = OldSignatureModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
-
-    with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
-        trainer.fit(model)
-
-
 def test_v1_5_0_model_checkpoint_period(tmpdir):
     with no_warning_call(DeprecationWarning):
         ModelCheckpoint(dirpath=tmpdir)


@@ -1098,3 +1098,17 @@ def test_multiple_optimizers_logging(precision, tmpdir):
     assert set(trainer.logged_metrics) == {"epoch", "loss_d", "loss_g"}
     assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"}
+
+
+def test_manual_optimization_training_step_signature(tmpdir):
+    """Test that Lightning raises an exception if the training_step signature has an optimizer_idx by mistake."""
+
+    class ConfusedAutomaticManualModel(ManualOptModel):
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            return super().training_step(batch, batch_idx)
+
+    model = ConfusedAutomaticManualModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.raises(ValueError, match="Your `LightningModule.training_step` signature contains an `optimizer_idx`"):
+        trainer.fit(model)
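On the user side, the fix is either of the two options named in the error message: drop the extra parameter and keep manual optimization, or set `self.automatic_optimization = True` and let Lightning drive the optimizers. A minimal sketch of the first option, reusing the test's `ManualOptModel` helper (the test name and subclass below are illustrative, not part of this commit):

```python
def test_fixed_manual_optimization_signature(tmpdir):
    class FixedManualModel(ManualOptModel):
        # Manual optimization: `training_step` no longer declares `optimizer_idx`.
        def training_step(self, batch, batch_idx):
            return super().training_step(batch, batch_idx)

    model = FixedManualModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    trainer.fit(model)  # completes without raising ValueError
```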