remove support for optimizer_idx in the training_step for manual optimization (#8576)
This commit is contained in:
parent
9c80727b8c
commit
7901d297d3
|
@ -61,10 +61,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed the deprecated `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
|
||||
|
||||
|
||||
|
||||
- Delete the deprecated `TrainerLoggingMixin` class ([#8609](https://github.com/PyTorchLightning/pytorch-lightning/pull/8609))
|
||||
|
||||
|
||||
-
|
||||
|
||||
- Removed the deprecated `optimizer_idx` from `training_step` as an accepted argument in manual optimization ([#8576](https://github.com/PyTorchLightning/pytorch-lightning/pull/8576))
|
||||
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -160,19 +160,20 @@ class TrainingBatchLoop(Loop):
|
|||
return len(self.get_active_optimizers(batch_idx))
|
||||
|
||||
def _run_optimization(
|
||||
self, batch_idx: int, split_batch: Any, opt_idx: int = 0, optimizer: Optional[torch.optim.Optimizer] = None
|
||||
self,
|
||||
batch_idx: int,
|
||||
split_batch: Any,
|
||||
opt_idx: Optional[int] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
"""Runs closure (train step + backward) together with optimization if necessary.
|
||||
|
||||
Args:
|
||||
batch_idx: the index of the current batch
|
||||
split_batch: the current tbptt split of the whole batch
|
||||
opt_idx: the index of the current optimizer
|
||||
optimizer: the current optimizer
|
||||
opt_idx: the index of the current optimizer or `None` in case of manual optimization
|
||||
optimizer: the current optimizer or `None` in case of manual optimization
|
||||
"""
|
||||
# TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
|
||||
# opt_idx=0 to opt_idx=None in the signature here
|
||||
|
||||
# toggle model params
|
||||
self._run_optimization_start(opt_idx, optimizer)
|
||||
|
||||
|
@ -625,10 +626,10 @@ class TrainingBatchLoop(Loop):
|
|||
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
|
||||
if has_opt_idx_in_train_step:
|
||||
if not lightning_module.automatic_optimization:
|
||||
self._warning_cache.deprecation(
|
||||
"`training_step` hook signature has changed in v1.3."
|
||||
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
|
||||
" the old signature will be removed in v1.5"
|
||||
raise ValueError(
|
||||
"Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
|
||||
" in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
|
||||
" argument or set `self.automatic_optimization = True`."
|
||||
)
|
||||
step_kwargs["optimizer_idx"] = opt_idx
|
||||
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
|
||||
|
|
|
@ -166,26 +166,6 @@ def test_v1_5_0_running_sanity_check():
|
|||
assert not trainer.running_sanity_check
|
||||
|
||||
|
||||
def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
|
||||
class OldSignatureModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
assert optimizer_idx == 0
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
|
||||
|
||||
model = OldSignatureModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
||||
|
||||
with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_v1_5_0_model_checkpoint_period(tmpdir):
|
||||
with no_warning_call(DeprecationWarning):
|
||||
ModelCheckpoint(dirpath=tmpdir)
|
||||
|
|
|
@ -1098,3 +1098,17 @@ def test_multiple_optimizers_logging(precision, tmpdir):
|
|||
|
||||
assert set(trainer.logged_metrics) == {"epoch", "loss_d", "loss_g"}
|
||||
assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"}
|
||||
|
||||
|
||||
def test_manual_optimization_training_step_signature(tmpdir):
|
||||
"""Test that Lightning raises an exception if the training_step signature has an optimier_idx by mistake."""
|
||||
|
||||
class ConfusedAutomaticManualModel(ManualOptModel):
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
model = ConfusedAutomaticManualModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
||||
|
||||
with pytest.raises(ValueError, match="Your `LightningModule.training_step` signature contains an `optimizer_idx`"):
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue