Fix `LightningModule.{un,}toggle_model` when only 1 optimizer is used (#12088)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Author: Cai Q.T · 2022-02-28 20:41:51 +08:00 · committed by GitHub
parent 17bb815d01
commit 01c31ae434
3 changed files with 22 additions and 2 deletions


@@ -666,6 +666,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the mid-epoch warning call while resuming training ([#11556](https://github.com/PyTorchLightning/pytorch-lightning/pull/11556))
+- Fixed `LightningModule.{un,}toggle_model` when only 1 optimizer is used ([#12088](https://github.com/PyTorchLightning/pytorch-lightning/pull/12088))
 - Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))


@@ -1384,7 +1384,7 @@ class LightningModule(
         # Iterate over all optimizer parameters to preserve their `requires_grad` information
         # in case these are pre-defined during `configure_optimizers`
         param_requires_grad_state = {}
-        for opt in self.optimizers(use_pl_optimizer=False):
+        for opt in self.trainer.optimizers:
             for group in opt.param_groups:
                 for param in group["params"]:
                     # If a param already appear in param_requires_grad_state, continue
@@ -1408,7 +1408,7 @@ class LightningModule(
         Args:
             optimizer_idx: The index of the optimizer to untoggle.
         """
-        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+        for opt_idx, opt in enumerate(self.trainer.optimizers):
             if optimizer_idx != opt_idx:
                 for group in opt.param_groups:
                     for param in group["params"]:
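
Background on the change (not part of the diff, based on the public PyTorch Lightning API at the time): with a single optimizer, `LightningModule.optimizers(use_pl_optimizer=False)` returns the bare optimizer object rather than a list, so the old loops tried to iterate over a `torch.optim.Optimizer` and failed, whereas `self.trainer.optimizers` is always a list. A minimal standalone sketch of the difference, using plain `torch` outside of Lightning:

import torch

# A single parameter and optimizer standing in for a model configured with one optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# `self.trainer.optimizers` is always a list, so iterating yields optimizer objects:
for opt in [optimizer]:
    assert isinstance(opt, torch.optim.Optimizer)

# The old code iterated over whatever `optimizers()` returned; with one optimizer
# that is the optimizer object itself, and a bare optimizer is not iterable:
try:
    for opt in optimizer:
        pass
except TypeError as err:
    print(f"iterating a single optimizer fails: {err}")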


@@ -87,6 +87,23 @@ def test_property_loggers(tmpdir):
     assert model.loggers == [logger]


+def test_1_optimizer_toggle_model():
+    """Test toggle_model runs when only one optimizer is used."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # toggle optimizer was failing with a single optimizer
+    model.toggle_optimizer(optimizer, 0)
+    assert model._param_requires_grad_state
+    model.untoggle_optimizer(0)
+    assert not model._param_requires_grad_state
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx, optimizer_idx=None):
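
Usage-level illustration (a hedged sketch, not code from this commit; the model, data sizes, and learning rate are invented for the example): after this fix, `toggle_optimizer`/`untoggle_optimizer` can be called during manual optimization even when `configure_optimizers` returns a single optimizer.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class SingleOptimizerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # Freeze parameters owned by other optimizers (a no-op with one optimizer,
        # but this call used to raise before #12088).
        self.toggle_optimizer(opt, optimizer_idx=0)
        loss = self.layer(batch[0]).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        self.untoggle_optimizer(0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32))
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(SingleOptimizerModel(), DataLoader(dataset, batch_size=8))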