Fix user warning produced by apex + scheduler combination (#1873)

* fix user warning produced by apex + scheduler combination

* add changelog

* added reinit to every configure_apex call

* fix styling

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Nicki Skafte authored on 2020-05-22 13:19:37 +02:00, committed by GitHub
parent d610f3bb53
commit 8f6b7a2b4f
4 changed files with 19 additions and 0 deletions
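
For context, the warning this commit removes comes from PyTorch's learning rate schedulers: since torch 1.1, `_LRScheduler.__init__` wraps `optimizer.step` so that `scheduler.step()` can verify the call order. When Apex's `amp.initialize` subsequently patches the optimizer, that wrapper is discarded and PyTorch warns that `optimizer.step()` has been overridden after scheduler initialization, even though the user's code is correct. Below is a minimal sketch of the failure mode and of the re-initialization trick this commit applies (illustrative only; assumes the `apex` package and a CUDA device are available):

```python
import torch
from torch import optim
from apex import amp  # assumes NVIDIA apex is installed

model = torch.nn.Linear(10, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)
# the scheduler wraps optimizer.step here to track the optimizer/scheduler call order
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

# apex patches optimizer.step for loss scaling, discarding the scheduler's wrapper;
# the next scheduler.step() would then emit the UserWarning this commit avoids
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# the fix: re-run the base scheduler __init__ against the (now patched) optimizer,
# which is what the new reinit_scheduler_properties helper does after configure_apex
optim.lr_scheduler._LRScheduler.__init__(scheduler, optimizer)
```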


@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

## [0.7.6] - 2020-05-16

### Added


@@ -372,6 +372,7 @@ class TrainerDDPMixin(ABC):
        if self.use_amp and not self.use_native_amp:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # DDP2 uses all GPUs on the machine
        if self.distributed_backend == 'ddp':


@@ -497,6 +497,7 @@ class TrainerDPMixin(ABC):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        self.run_pretrain_routine(model)

@@ -559,6 +560,7 @@ class TrainerDPMixin(ABC):
                    f' We recommend you switch to ddp if you want to use amp')
            else:
                model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)

        # create list of device ids
        device_ids = self.data_parallel_device_ids

@@ -599,6 +601,7 @@ class TrainerDPMixin(ABC):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)


@@ -108,6 +108,19 @@ class TrainerOptimizersMixin(ABC):
                                 'is a invalid input.')
        return lr_schedulers

    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
        # Reinitialize optimizer.step properties added by schedulers
        for scheduler in schedulers:
            scheduler = scheduler['scheduler']
            for optimizer in optimizers:
                # check that we don't mix the user's optimizers and schedulers
                if scheduler.optimizer == optimizer:
                    # find the entry in the MRO belonging to the base lr scheduler class
                    for i, mro in enumerate(scheduler.__class__.__mro__):
                        if mro == optim.lr_scheduler._LRScheduler:
                            idx = i
                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`