Fix user warning produced by apex + scheduler combination (#1873)
* fix user warning produced by apex + scheduler combination
* add changelog
* add reinit to every configure_apex call
* fix styling

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
parent d610f3bb53
commit 8f6b7a2b4f
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

## [0.7.6] - 2020-05-16

### Added
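For context on the code hunks that follow: the warning comes from torch.optim.lr_scheduler. The base scheduler class wraps optimizer.step when the scheduler is constructed, and apex's amp.initialize later rebinds optimizer.step, so that bookkeeping is lost and the first scheduler.step() reports that optimizer.step() was overridden. Below is a minimal sketch, assuming only plain PyTorch (>= 1.1, no apex or Lightning needed), that imitates the rebinding to reproduce the warning; patched_step is a purely illustrative stand-in for apex's patched step.

import types
import warnings

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# The scheduler is created first, so it wraps the *original* optimizer.step.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Imitate apex: amp.initialize rebinds optimizer.step after the scheduler exists.
_original_step = optimizer.step

def patched_step(self, *args, **kwargs):  # illustrative stand-in for apex's patched step
    return _original_step(*args, **kwargs)

optimizer.step = types.MethodType(patched_step, optimizer)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()
    scheduler.step()  # first call after the rebinding triggers the UserWarning

print([str(w.message) for w in caught])
# expected: "Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. ..."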
@@ -372,6 +372,7 @@ class TrainerDDPMixin(ABC):
        if self.use_amp and not self.use_native_amp:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # DDP2 uses all GPUs on the machine
        if self.distributed_backend == 'ddp':
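The configure_apex hook called above hands the model and optimizers to NVIDIA apex. As a hedged sketch of the ordering problem, assuming apex and a CUDA device are available: a scheduler built after amp.initialize binds to the already patched optimizer and never warns, but inside Lightning the schedulers already exist by the time this code runs, hence the extra reinit_scheduler_properties call.

import torch
from torch import nn, optim
from apex import amp  # assumes NVIDIA apex is installed

model = nn.Linear(4, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# amp.initialize returns a patched model and optimizer (optimizer.step is rebound).
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Building the scheduler *after* amp.initialize sidesteps the warning entirely;
# Lightning cannot reorder this, so it re-initializes the existing schedulers instead.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)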
@@ -497,6 +497,7 @@ class TrainerDPMixin(ABC):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        self.run_pretrain_routine(model)
@@ -559,6 +560,7 @@ class TrainerDPMixin(ABC):
                    f' We recommend you switch to ddp if you want to use amp')
            else:
                model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)

        # create list of device ids
        device_ids = self.data_parallel_device_ids
@@ -599,6 +601,7 @@ class TrainerDPMixin(ABC):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
@@ -108,6 +108,19 @@ class TrainerOptimizersMixin(ABC):
                                 'is a invalid input.')
        return lr_schedulers

    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
        # Reinitialize optimizer.step properties added by schedulers
        for scheduler in schedulers:
            for optimizer in optimizers:
                scheduler = scheduler['scheduler']
                # check that we dont mix users optimizers and schedulers
                if scheduler.optimizer == optimizer:
                    # Find the mro belonging to the base lr scheduler class
                    for i, mro in enumerate(scheduler.__class__.__mro__):
                        if mro == optim.lr_scheduler._LRScheduler:
                            idx = i
                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
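To see what the re-initialization buys outside of Lightning, here is a sketch in plain PyTorch that imitates apex rebinding optimizer.step and then re-runs the base scheduler __init__, which is what the __mro__ lookup in reinit_scheduler_properties resolves to. Re-running __init__ also resets the scheduler's last_epoch, which should be harmless at this point because the trainer performs the re-init before run_pretrain_routine, i.e. before any training steps.

import types
import warnings

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Imitate apex rebinding optimizer.step after the scheduler was constructed.
_original_step = optimizer.step
optimizer.step = types.MethodType(
    lambda self, *args, **kwargs: _original_step(*args, **kwargs), optimizer)

# The fix: re-run the base scheduler __init__ against the patched optimizer so
# its step-tracking wrapper is installed again (this is the class the __mro__
# walk above ends up calling).
optim.lr_scheduler._LRScheduler.__init__(scheduler, optimizer)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()
    scheduler.step()

# No "overridden after learning rate scheduler initialization" warning expected.
assert not any('overridden' in str(w.message) for w in caught)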