From 8f6b7a2b4fea9b7bd0b873f5973e6364b3981412 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 22 May 2020 13:19:37 +0200
Subject: [PATCH] Fix user warning produced by apex + scheduler combination
 (#1873)

* fix user error produced by apex + scheduler combination

* add changelog

* added reinit to every configure_apex call

* fix styling

Co-authored-by: Nicki Skafte
---
 CHANGELOG.md                                        |  2 ++
 pytorch_lightning/trainer/distrib_data_parallel.py  |  1 +
 pytorch_lightning/trainer/distrib_parts.py          |  3 +++
 pytorch_lightning/trainer/optimizers.py             | 13 +++++++++++++
 4 files changed, 19 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8392d65a47..3f893a298b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index ae55d24470..3844896ba2 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -372,6 +372,7 @@ class TrainerDDPMixin(ABC):
         if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # DDP2 uses all GPUs on the machine
         if self.distributed_backend == 'ddp':
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index a08c7a1c07..dabe72f27d 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -497,6 +497,7 @@ class TrainerDPMixin(ABC):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         self.run_pretrain_routine(model)
 
@@ -559,6 +560,7 @@ class TrainerDPMixin(ABC):
                     f' We recommend you switch to ddp if you want to use amp')
             else:
                 model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
+                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
 
         # create list of device ids
         device_ids = self.data_parallel_device_ids
@@ -599,6 +601,7 @@ class TrainerDPMixin(ABC):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(model.state_dict(), root_rank=0)
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index ea33e3e0bb..e49fb96ba3 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -108,6 +108,19 @@ class TrainerOptimizersMixin(ABC):
                                  'is a invalid input.')
         return lr_schedulers
 
+    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
+        # Reinitialize optimizer.step properties added by schedulers
+        for scheduler in schedulers:
+            for optimizer in optimizers:
+                scheduler = scheduler['scheduler']
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        if mro == optim.lr_scheduler._LRScheduler:
+                            idx = i
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+
 
 class _MockOptimizer(Optimizer):
     """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
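
A minimal, standalone sketch of the mechanism the new `reinit_scheduler_properties` helper relies on (no apex needed): PyTorch's base LR-scheduler `__init__` instruments `optimizer.step` so that `scheduler.step()` can warn when the call order looks wrong. Once `amp.initialize` hands back a re-patched optimizer, that instrumentation no longer matches the live object and the spurious UserWarning appears; re-running the base-class `__init__` found via the scheduler's MRO re-binds the scheduler to the optimizer actually in use. The toy model/optimizer, the simplified helper (no `scheduler['scheduler']` dict unwrapping, no optimizer-identity check), and the `getattr` fallback for PyTorch releases that renamed `_LRScheduler` to `LRScheduler` are assumptions of this sketch, not Lightning code.

    from torch import nn, optim

    # The base scheduler class was renamed from ``_LRScheduler`` to ``LRScheduler``
    # in newer PyTorch releases; resolve whichever one this installation provides.
    BaseScheduler = getattr(optim.lr_scheduler, "LRScheduler",
                            optim.lr_scheduler._LRScheduler)

    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Stand-in for what amp.initialize does: the optimizer the scheduler was built
    # against is replaced by another object whose .step carries no instrumentation.
    new_optimizer = optim.SGD(model.parameters(), lr=0.1)
    new_optimizer.load_state_dict(optimizer.state_dict())

    def reinit_scheduler_properties(optimizers, schedulers):
        # Re-run the base scheduler __init__ so its step-order bookkeeping
        # points at the optimizer that will actually be stepped.
        for sched in schedulers:
            for opt in optimizers:
                # Walk the MRO to find the base lr scheduler class.
                for klass in sched.__class__.__mro__:
                    if klass is BaseScheduler:
                        klass.__init__(sched, opt)
                        break

    reinit_scheduler_properties([new_optimizer], [scheduler])

    new_optimizer.step()
    scheduler.step()  # without the reinit above, this would emit a UserWarning
                      # about the optimizer/scheduler call order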