From c275e1fc91df4d351799b633e9df08e010094bfe Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 25 Jun 2020 09:21:41 -0400
Subject: [PATCH] swaps lr sched order (#2356)

* swaps lr sched order

* Update optimizers.py

* added amdim encoder choice
---
 pytorch_lightning/trainer/optimizers.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index e49fb96ba3..663871d98b 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -111,15 +111,25 @@ class TrainerOptimizersMixin(ABC):
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
+            scheduler = scheduler['scheduler']
+
             for optimizer in optimizers:
-                scheduler = scheduler['scheduler']
                 # check that we dont mix users optimizers and schedulers
                 if scheduler.optimizer == optimizer:
                     # Find the mro belonging to the base lr scheduler class
                     for i, mro in enumerate(scheduler.__class__.__mro__):
-                        if mro == optim.lr_scheduler._LRScheduler:
+                        if (
+                            mro == optim.lr_scheduler._LRScheduler
+                            or mro == optim.lr_scheduler.ReduceLROnPlateau
+                        ):
                             idx = i
-                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                            state = scheduler.state_dict()
+                        else:
+                            state = None
+
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                    if state is not None:
+                        scheduler.load_state_dict(state)
 
 
 class _MockOptimizer(Optimizer):
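
Below is a minimal, torch-only sketch of the behaviour the hunk above targets, included here for illustration only (it is not part of the patch). ReduceLROnPlateau does not inherit from optim.lr_scheduler._LRScheduler, so the MRO check is widened to cover it, and the scheduler's state_dict is captured before re-running the base __init__ and restored afterwards so accumulated state (best metric, num_bad_epochs, ...) is not lost. The sketch stops at the first matching base class as a simplification, and the model/optimizer names are purely illustrative.

    from torch import nn, optim

    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
    scheduler.step(1.0)  # accumulate some internal state (best, num_bad_epochs, ...)

    # Locate the base scheduler class in the MRO and snapshot the current state.
    state = None
    for i, mro in enumerate(scheduler.__class__.__mro__):
        if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
            idx = i
            state = scheduler.state_dict()
            break

    # Re-run the base __init__ against the optimizer, then restore the snapshot
    # so the scheduler keeps its accumulated state.
    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
    if state is not None:
        scheduler.load_state_dict(state)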