Fix/mismatched toggle optimizer (#7563)

* fix: avoid potential mismatched toggling of optimizer. Refs #7405
  chore: update CHANGELOG
  [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
  fix: resolve a conflict
  chore: update changelog

* feat: add a test that fails in master

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix typo in tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Polish tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Polish tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix: change placeholder in optimizer_step from positional args to keyword args

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 2242423b75
commit 01109cdf0c
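For context, here is a minimal sketch of the behaviour this commit changes. It is not Lightning's actual code: the helper `run_one_optimizer` below is hypothetical, and only the hook names (`toggle_optimizer`, `untoggle_optimizer`, `optimizer_step`, `training_step`) come from the diffs that follow.

```python
# Hypothetical sketch, not Lightning's implementation: why untoggling inside
# the closure can be skipped, while untoggling after optimizer_step cannot.

def run_one_optimizer(module, optimizer, opt_idx, batch, batch_idx, epoch):
    # Freeze parameters that belong to the other optimizers.
    module.toggle_optimizer(optimizer, opt_idx)

    def closure():
        loss = module.training_step(batch, batch_idx, opt_idx)
        loss.backward()
        # Before this fix, untoggle_optimizer(opt_idx) was effectively called
        # here, i.e. only when the closure actually ran.
        return loss

    # A user-overridden optimizer_step may decide not to call
    # optimizer.step(closure=closure) for this batch (see the test added
    # below), in which case the closure never runs.
    module.optimizer_step(
        epoch=epoch,
        batch_idx=batch_idx,
        optimizer=optimizer,
        optimizer_idx=opt_idx,
        optimizer_closure=closure,
    )

    # After this fix, un-toggling happens unconditionally, once
    # optimizer_step has returned.
    module.untoggle_optimizer(opt_idx)
```

With a custom `optimizer_step` that skips `optimizer.step(closure=...)` on some batches (as the test added below does for its second optimizer), the old placement could leave that optimizer's parameters toggled off for the following batch, which is the mismatched toggling referred to in the title.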
CHANGELOG.md
@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
- Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
pytorch_lightning/trainer/training_loop.py
@@ -724,7 +724,6 @@ class TrainLoop:

                    # -------------------
                    # calculate loss (train step + train step end)
                    # -------------------

                    # automatic_optimization=True: perform ddp sync only when performing optimizer_step
                    # automatic_optimization=False: don't block synchronization here
                    with self.block_ddp_sync_behaviour():
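`block_ddp_sync_behaviour()` itself is untouched by this commit. As a rough, hypothetical illustration of the comment above (skip gradient synchronization when no `optimizer_step` will follow), the usual building block is DDP's `no_sync()` context manager:

```python
import contextlib

import torch


def maybe_block_ddp_sync(model, block: bool):
    # Hypothetical helper: skip DDP's gradient all-reduce while accumulating,
    # otherwise behave as a no-op context manager.
    if block and isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.no_sync()
    return contextlib.nullcontext()
```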
@@ -737,6 +736,9 @@ class TrainLoop:

                else:
                    if self.trainer.lightning_module.automatic_optimization:
                        self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
                        if len(self.trainer.optimizers) > 1:
                            # revert back to previous state
                            self.trainer.lightning_module.untoggle_optimizer(opt_idx)
                    else:
                        result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
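The `if len(self.trainer.optimizers) > 1:` block shown above is the new home of the `untoggle_optimizer(opt_idx)` call; the hunk below removes the old call site, which was reached only through the closure. As a simplified standalone sketch of what toggling amounts to, assuming it saves and restores `requires_grad` for the parameters of the non-active optimizers (Lightning's real implementation lives in `LightningModule` and may differ in detail):

```python
def toggle(optimizers, active_idx, saved):
    # Freeze parameters owned by every optimizer except the active one,
    # remembering their previous requires_grad state.
    for i, opt in enumerate(optimizers):
        if i == active_idx:
            continue
        for group in opt.param_groups:
            for p in group["params"]:
                saved[p] = p.requires_grad
                p.requires_grad = False


def untoggle(optimizers, active_idx, saved):
    # Revert back to the state recorded by toggle().
    for i, opt in enumerate(optimizers):
        if i == active_idx:
            continue
        for group in opt.param_groups:
            for p in group["params"]:
                if p in saved:
                    p.requires_grad = saved.pop(p)
```

If `untoggle()` is ever skipped for a batch, those frozen parameters stay frozen for the next optimizer's `training_step`, which is exactly what calling it unconditionally after `optimizer_step` prevents.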
@@ -837,10 +839,6 @@ class TrainLoop:

                "training_step returned None. If this was on purpose, ignore this warning..."
            )

        if len(self.trainer.optimizers) > 1:
            # revert back to previous state
            self.trainer.lightning_module.untoggle_optimizer(opt_idx)

        return result

    def _check_finite(self, loss: torch.Tensor) -> None:
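`_check_finite`, whose signature closes the hunk above, is not modified here; as a hedged guess at what such a check typically looks like (not necessarily Lightning's exact code):

```python
import torch


def check_finite(loss: torch.Tensor) -> None:
    # Fail fast when the training loss has become NaN or infinite.
    if not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}, which is not finite.")
```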
tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def test_multiple_optimizers_no_opt_idx_argument(tmpdir):

    with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
        trainer.fit(TestModel())


def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
    """
    This test ensures custom optimizer_step works,
    even when optimizer.step is not called for a particular optimizer
    """

    class TestModel(BoringModel):
        training_step_called = [0, 0]
        optimizer_step_called = [0, 0]

        def __init__(self):
            super().__init__()
            self.layer_a = torch.nn.Linear(32, 2)
            self.layer_b = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
            return opt_a, opt_b

        def training_step(self, batch, batch_idx, optimizer_idx):
            self.training_step_called[optimizer_idx] += 1
            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
            return loss

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            **_,
        ):
            # update first optimizer every step
            if optimizer_idx == 0:
                self.optimizer_step_called[optimizer_idx] += 1
                optimizer.step(closure=optimizer_closure)

            # update second optimizer every 2 steps
            if optimizer_idx == 1:
                if batch_idx % 2 == 0:
                    self.optimizer_step_called[optimizer_idx] += 1
                    optimizer.step(closure=optimizer_closure)

    model = TestModel()
    model.val_dataloader = None

    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=4,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)
    assert model.training_step_called == [4, 2]
    assert model.optimizer_step_called == [4, 2]
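Why the expected counts are `[4, 2]`: with `limit_train_batches=4`, the first optimizer calls `optimizer.step(closure=optimizer_closure)` on every batch, so its closure (and therefore `training_step`) runs 4 times. The second optimizer only steps when `batch_idx % 2 == 0`, i.e. on batches 0 and 2, so both of its counters end at 2, since under automatic optimization `training_step` for a given optimizer only runs through that optimizer's closure.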
|