diff --git a/CHANGELOG.md b/CHANGELOG.md
index da46315b6f..190771e8f8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -189,6 +189,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))
 
+- Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))
+
+
 ## [1.1.4] - 2021-01-12
 
 ### Added
 
diff --git a/pytorch_lightning/accelerators/legacy/dp_accelerator.py b/pytorch_lightning/accelerators/legacy/dp_accelerator.py
index 5bee429597..9c1cd7aa7d 100644
--- a/pytorch_lightning/accelerators/legacy/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/dp_accelerator.py
@@ -146,30 +146,6 @@ class DataParallelAccelerator(Accelerator):
         output = output.mean()
         return output
 
-    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
-        """
-        Reinitialize optimizer.step properties added by schedulers
-        """
-        for scheduler in schedulers:
-            scheduler = scheduler['scheduler']
-
-            for optimizer in optimizers:
-                # check that we dont mix users optimizers and schedulers
-                if scheduler.optimizer == optimizer:
-                    # Find the mro belonging to the base lr scheduler class
-                    for i, mro in enumerate(scheduler.__class__.__mro__):
-                        is_regular_scheduler = optim.lr_scheduler._LRScheduler
-                        is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
-                        if is_regular_scheduler or is_lr_reduce_on_plateau:
-                            idx = i
-                            state = scheduler.state_dict()
-                        else:
-                            state = None
-
-                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
-                    if state is not None:
-                        scheduler.load_state_dict(state)
-
     def get_reference_model(self, model) -> LightningModule:
         if isinstance(model, torch.nn.DataParallel):
             model = model.module
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index a00c8b5fbf..6772dcc645 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -149,6 +149,7 @@ class TrainerOptimizersMixin(ABC):
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
             scheduler = scheduler['scheduler']
+            state = None
 
             for optimizer in optimizers:
                 # check that we dont mix users optimizers and schedulers
@@ -156,14 +157,13 @@ class TrainerOptimizersMixin(ABC):
                     # Find the mro belonging to the base lr scheduler class
                     for i, mro in enumerate(scheduler.__class__.__mro__):
                         if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                            idx = i
                             state = scheduler.state_dict()
-                        else:
-                            state = None
+                            scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
+                            scheduler.load_state_dict(state)
+                            break
 
-                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
                     if state is not None:
-                        scheduler.load_state_dict(state)
+                        break
 
 
 class _MockOptimizer(Optimizer):
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index c9f6ea05ad..509e60ceca 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -16,6 +16,7 @@ from unittest import mock
 
 import pytest
 import torch
+from torch import optim
 
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
@@ -188,9 +189,15 @@ def test_amp_without_apex(tmpdir):
 @pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
 def test_amp_with_apex(tmpdir):
     """Check calling apex scaling in training."""
+    class CustomModel(EvalModelTemplate):
+        def configure_optimizers(self):
+            optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
+            optimizer2 = optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
+            lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
+            return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
 
-    model = EvalModelTemplate()
-
+    model = CustomModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -201,4 +208,7 @@
     assert str(trainer.amp_backend) == "AMPType.APEX"
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
-    assert trainer.dev_debugger.count_events('AMP') == 10
+    assert trainer.dev_debugger.count_events('AMP') == 20
+
+    assert isinstance(trainer.lr_schedulers[0]['scheduler'].optimizer, optim.Adam)
+    assert isinstance(trainer.lr_schedulers[1]['scheduler'].optimizer, optim.SGD)
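
For reference, here is a minimal standalone sketch of the corrected reinitialization loop from the `pytorch_lightning/trainer/optimizers.py` hunk above, run outside the Trainer so its effect is easy to see. The setup objects (`model`, `optimizer1`, `optimizer2`, the `schedulers` list) are illustrative, not Lightning API, and the sketch assumes a torch release contemporary with this PR (around 1.7), where `StepLR` derives directly from `optim.lr_scheduler._LRScheduler`.

```python
from torch import nn, optim

# Illustrative stand-ins for the objects the Trainer would normally own.
model = nn.Linear(2, 2)
optimizer1 = optim.Adam(model.parameters(), lr=1e-3)
optimizer2 = optim.SGD(model.parameters(), lr=1e-3)
optimizers = [optimizer1, optimizer2]
schedulers = [
    {'scheduler': optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)},
    {'scheduler': optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)},
]

# Corrected loop, mirroring the patched reinit_scheduler_properties.
for scheduler in schedulers:
    scheduler = scheduler['scheduler']
    state = None  # reset per scheduler so a match from a previous iteration cannot leak in

    for optimizer in optimizers:
        # only re-init a scheduler against the optimizer it already wraps
        if scheduler.optimizer == optimizer:
            # find the base lr scheduler class in the MRO and re-run its __init__
            for i, mro in enumerate(scheduler.__class__.__mro__):
                if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    state = scheduler.state_dict()
                    scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
                    scheduler.load_state_dict(state)
                    break

            if state is not None:
                break  # matching optimizer handled; move on to the next scheduler

# Scheduler/optimizer pairings survive reinitialization.
assert isinstance(schedulers[0]['scheduler'].optimizer, optim.Adam)
assert isinstance(schedulers[1]['scheduler'].optimizer, optim.SGD)
```

With `state = None` reset per scheduler and the early `break`s, each scheduler is re-initialized only against the optimizer it already wraps; the closing asserts mirror the checks added to `test_amp_with_apex` in this diff.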