remove support for optimizer_idx in the training_step for manual optimization (#8576)
This commit is contained in:
parent
9c80727b8c
commit
7901d297d3
|
@ -61,10 +61,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Removed the deprecated `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
|
- Removed the deprecated `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- Delete the deprecated `TrainerLoggingMixin` class ([#8609](https://github.com/PyTorchLightning/pytorch-lightning/pull/8609))
|
- Delete the deprecated `TrainerLoggingMixin` class ([#8609](https://github.com/PyTorchLightning/pytorch-lightning/pull/8609))
|
||||||
|
|
||||||
|
|
||||||
-
|
|
||||||
|
- Removed the deprecated `optimizer_idx` from `training_step` as an accepted argument in manual optimization ([#8576](https://github.com/PyTorchLightning/pytorch-lightning/pull/8576))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
-
|
-
|
||||||
|
|
|
@ -160,19 +160,20 @@ class TrainingBatchLoop(Loop):
|
||||||
return len(self.get_active_optimizers(batch_idx))
|
return len(self.get_active_optimizers(batch_idx))
|
||||||
|
|
||||||
def _run_optimization(
|
def _run_optimization(
|
||||||
self, batch_idx: int, split_batch: Any, opt_idx: int = 0, optimizer: Optional[torch.optim.Optimizer] = None
|
self,
|
||||||
|
batch_idx: int,
|
||||||
|
split_batch: Any,
|
||||||
|
opt_idx: Optional[int] = None,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
):
|
):
|
||||||
"""Runs closure (train step + backward) together with optimization if necessary.
|
"""Runs closure (train step + backward) together with optimization if necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_idx: the index of the current batch
|
batch_idx: the index of the current batch
|
||||||
split_batch: the current tbptt split of the whole batch
|
split_batch: the current tbptt split of the whole batch
|
||||||
opt_idx: the index of the current optimizer
|
opt_idx: the index of the current optimizer or `None` in case of manual optimization
|
||||||
optimizer: the current optimizer
|
optimizer: the current optimizer or `None` in case of manual optimization
|
||||||
"""
|
"""
|
||||||
# TODO(@awaelchli): In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
|
|
||||||
# opt_idx=0 to opt_idx=None in the signature here
|
|
||||||
|
|
||||||
# toggle model params
|
# toggle model params
|
||||||
self._run_optimization_start(opt_idx, optimizer)
|
self._run_optimization_start(opt_idx, optimizer)
|
||||||
|
|
||||||
|
@ -625,10 +626,10 @@ class TrainingBatchLoop(Loop):
|
||||||
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
|
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
|
||||||
if has_opt_idx_in_train_step:
|
if has_opt_idx_in_train_step:
|
||||||
if not lightning_module.automatic_optimization:
|
if not lightning_module.automatic_optimization:
|
||||||
self._warning_cache.deprecation(
|
raise ValueError(
|
||||||
"`training_step` hook signature has changed in v1.3."
|
"Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
|
||||||
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
|
" in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
|
||||||
" the old signature will be removed in v1.5"
|
" argument or set `self.automatic_optimization = True`."
|
||||||
)
|
)
|
||||||
step_kwargs["optimizer_idx"] = opt_idx
|
step_kwargs["optimizer_idx"] = opt_idx
|
||||||
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
|
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
|
||||||
|
|
|
@ -166,26 +166,6 @@ def test_v1_5_0_running_sanity_check():
|
||||||
assert not trainer.running_sanity_check
|
assert not trainer.running_sanity_check
|
||||||
|
|
||||||
|
|
||||||
def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
|
|
||||||
class OldSignatureModel(BoringModel):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.automatic_optimization = False
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
|
||||||
assert optimizer_idx == 0
|
|
||||||
return super().training_step(batch, batch_idx)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
|
|
||||||
|
|
||||||
model = OldSignatureModel()
|
|
||||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
|
||||||
|
|
||||||
with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
|
|
||||||
trainer.fit(model)
|
|
||||||
|
|
||||||
|
|
||||||
def test_v1_5_0_model_checkpoint_period(tmpdir):
|
def test_v1_5_0_model_checkpoint_period(tmpdir):
|
||||||
with no_warning_call(DeprecationWarning):
|
with no_warning_call(DeprecationWarning):
|
||||||
ModelCheckpoint(dirpath=tmpdir)
|
ModelCheckpoint(dirpath=tmpdir)
|
||||||
|
|
|
@ -1098,3 +1098,17 @@ def test_multiple_optimizers_logging(precision, tmpdir):
|
||||||
|
|
||||||
assert set(trainer.logged_metrics) == {"epoch", "loss_d", "loss_g"}
|
assert set(trainer.logged_metrics) == {"epoch", "loss_d", "loss_g"}
|
||||||
assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"}
|
assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_manual_optimization_training_step_signature(tmpdir):
|
||||||
|
"""Test that Lightning raises an exception if the training_step signature has an optimier_idx by mistake."""
|
||||||
|
|
||||||
|
class ConfusedAutomaticManualModel(ManualOptModel):
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
return super().training_step(batch, batch_idx)
|
||||||
|
|
||||||
|
model = ConfusedAutomaticManualModel()
|
||||||
|
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Your `LightningModule.training_step` signature contains an `optimizer_idx`"):
|
||||||
|
trainer.fit(model)
|
||||||
|
|
Loading…
Reference in New Issue