diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index e49fb96ba3..663871d98b 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -111,15 +111,25 @@ class TrainerOptimizersMixin(ABC):
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
+            scheduler = scheduler['scheduler']
+
             for optimizer in optimizers:
-                scheduler = scheduler['scheduler']
                 # check that we dont mix users optimizers and schedulers
                 if scheduler.optimizer == optimizer:
                     # Find the mro belonging to the base lr scheduler class
                     for i, mro in enumerate(scheduler.__class__.__mro__):
-                        if mro == optim.lr_scheduler._LRScheduler:
+                        if (
+                            mro == optim.lr_scheduler._LRScheduler
+                            or mro == optim.lr_scheduler.ReduceLROnPlateau
+                        ):
                             idx = i
-                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                            state = scheduler.state_dict()
+                        else:
+                            state = None
+
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                    if state is not None:
+                        scheduler.load_state_dict(state)
 
 
 class _MockOptimizer(Optimizer):
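
As a minimal sketch (not part of the patch), the snippet below illustrates why the new ReduceLROnPlateau branch also saves and restores state_dict(): on the PyTorch versions this patch targets, ReduceLROnPlateau does not inherit from _LRScheduler, and re-running its __init__ in place resets internal counters such as num_bad_epochs. The toy Linear model, SGD optimizer, and patience value are assumptions made only for illustration.

# Minimal sketch (not part of the patch): re-running __init__ on a
# ReduceLROnPlateau instance wipes its internal state unless the state dict is
# saved beforehand and restored afterwards, which is what the change above does.
import torch
from torch import optim

model = torch.nn.Linear(2, 2)          # toy model, assumption for illustration
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

# Simulate a few epochs without improvement so the scheduler accumulates state.
for val_loss in (1.0, 1.0, 1.0):
    scheduler.step(val_loss)
print(scheduler.num_bad_epochs)        # 2 -> progress towards the patience limit

# Re-initialising in place (as reinit_scheduler_properties does) resets that state...
state = scheduler.state_dict()
scheduler.__class__.__init__(scheduler, optimizer)
print(scheduler.num_bad_epochs)        # 0 -> counters were wiped

# ...so the patch restores the saved state afterwards.
scheduler.load_state_dict(state)
print(scheduler.num_bad_epochs)        # 2 again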