Optimizer Frequencies logic, and new configure_optimizers (#1269)
* init_optimizers accepts Dict, Sequence[Dict] and returns optimizer_frequencies. optimizer_frequencies was added as a member of Trainer.
* Optimizer frequencies logic implemented in training_loop. Description added to configure_optimizers in LightningModule.
* Optimizer frequencies tests added to test_gpu.
* Fixed formatting for merging PR #1269
* Apply suggestions from code review
* Apply suggestions from code review
  Co-Authored-By: Asaf Manor <32155911+asafmanor@users.noreply.github.com>
* Update trainer.py
* Moving get_optimizers_iterable() outside.
* Update note
* Apply suggestions from code review
* formatting
* formatting
* Update CHANGELOG.md
* formatting
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
commit aca8c7e6f3 (parent ee68d5ba8e)
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
 - Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
 - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
@@ -914,10 +914,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

         If you don't define this method Lightning will automatically use Adam(lr=1e-3)

-        Return: any of these 3 options:
-            - Single optimizer
-            - List or Tuple - List of optimizers
-            - Two lists - The first list has multiple optimizers, the second a list of LR schedulers
+        Return: any of these 5 options:
+            - Single optimizer.
+            - List or Tuple - List of optimizers.
+            - Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
+            - Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
+            - Tuple of dictionaries as described, with an optional `frequency` key.
+
+        Note:
+            The `frequency` value is an int corresponding to the number of sequential batches
+            optimized with the specific optimizer. It should be given to none or to all of the optimizers.
+            There is a difference between passing multiple optimizers in a list,
+            and passing multiple optimizers in dictionaries with a frequency of 1:
+            In the former case, all optimizers will operate on the given batch in each optimization step.
+            In the latter, only one optimizer will operate on the given batch at every step.

         Examples:
             .. code-block:: python
@@ -949,6 +959,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
                 dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
                 return [gen_opt, dis_opt], [gen_sched, dis_sched]

+                # example with optimizer frequencies
+                # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
+                # https://arxiv.org/abs/1704.00028
+                def configure_optimizers(self):
+                    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
+                    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
+                    n_critic = 5
+                    return (
+                        {'optimizer': dis_opt, 'frequency': n_critic},
+                        {'optimizer': gen_opt, 'frequency': 1}
+                    )
+
         Note:

             Some things to know:
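To make the Note above concrete, here is a small illustrative sketch (not part of the commit) contrasting the two multi-optimizer styles it describes; the parameters and learning rates are placeholders.

    import torch
    from torch.optim import Adam, SGD

    params = [torch.nn.Parameter(torch.zeros(1))]

    # Style 1: plain list -- every optimizer is stepped on every batch.
    def configure_optimizers_as_list():
        return [Adam(params, lr=0.01), SGD(params, lr=0.01)]

    # Style 2: dictionaries with frequency=1 -- the trainer alternates,
    # so exactly one optimizer is stepped on any given batch.
    def configure_optimizers_as_dicts():
        return (
            {'optimizer': Adam(params, lr=0.01), 'frequency': 1},
            {'optimizer': SGD(params, lr=0.01), 'frequency': 1},
        )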
@@ -304,7 +304,8 @@ class TrainerDDPMixin(ABC):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
+            self.init_optimizers(model.configure_optimizers())

         # MODEL
         # copy model to each gpu
@@ -459,7 +459,8 @@ class TrainerDPMixin(ABC):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
+            self.init_optimizers(model.configure_optimizers())

         if self.use_amp:
             # An example
@@ -485,7 +486,8 @@ class TrainerDPMixin(ABC):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
+            self.init_optimizers(model.configure_optimizers())

         # init 16 bit for TPU
         if self.precision == 16:
@@ -503,7 +505,8 @@ class TrainerDPMixin(ABC):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
+            self.init_optimizers(model.configure_optimizers())

         model.cuda(self.root_gpu)

@@ -3,7 +3,7 @@ import os
 import sys
 import warnings
 from argparse import ArgumentParser
-from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
+from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
 import distutils

 import torch
@@ -354,6 +354,7 @@ class Trainer(
         self.disable_validation = False
         self.lr_schedulers = []
         self.optimizers = None
+        self.optimizer_frequencies = []
         self.global_step = 0
         self.current_epoch = 0
         self.total_batches = 0
@@ -710,7 +711,8 @@ class Trainer(

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
+            self.init_optimizers(model.configure_optimizers())

         self.run_pretrain_routine(model)

@@ -756,31 +758,59 @@ class Trainer(

     def init_optimizers(
             self,
-            optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
-    ) -> Tuple[List, List]:
+            optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
+    ) -> Tuple[List, List, List]:

         # single output, single optimizer
-        if isinstance(optimizers, Optimizer):
-            return [optimizers], []
+        if isinstance(optim_conf, Optimizer):
+            return [optim_conf], [], []

         # two lists, optimizer + lr schedulers
-        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
-            optimizers, lr_schedulers = optimizers
+        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
+            optimizers, lr_schedulers = optim_conf
             lr_schedulers = self.configure_schedulers(lr_schedulers)
-            return optimizers, lr_schedulers
+            return optimizers, lr_schedulers, []
+
+        # single dictionary
+        elif isinstance(optim_conf, dict):
+            optimizer = optim_conf["optimizer"]
+            lr_scheduler = optim_conf.get("lr_scheduler", [])
+            lr_schedulers = []
+            if lr_scheduler:
+                lr_schedulers = self.configure_schedulers([lr_scheduler])
+            return [optimizer], lr_schedulers, []
+
+        # multiple dictionaries
+        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+            # take only the lr_scheduler if it exists and is defined - not None
+            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")]
+            # take only the frequency if it exists and is defined - not None
+            optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")]
+
+            # clean scheduler list
+            if lr_schedulers:
+                lr_schedulers = self.configure_schedulers(lr_schedulers)
+            # assert that if frequencies are present, they are given for all optimizers
+            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
+                raise ValueError("A frequency must be given to each optimizer.")
+            return optimizers, lr_schedulers, optimizer_frequencies

         # single list or tuple, multiple optimizer
-        elif isinstance(optimizers, (list, tuple)):
-            return optimizers, []
+        elif isinstance(optim_conf, (list, tuple)):
+            return list(optim_conf), [], []

         # unknown configuration
         else:
-            raise ValueError('Unknown configuration for model optimizers. Output'
-                             'from model.configure_optimizers() should either be:'
-                             '* single output, single torch.optim.Optimizer'
-                             '* single output, list of torch.optim.Optimizer'
-                             '* two outputs, first being a list of torch.optim.Optimizer',
-                             'second being a list of torch.optim.lr_scheduler')
+            raise ValueError(
+                'Unknown configuration for model optimizers.'
+                ' Output from `model.configure_optimizers()` should either be:'
+                ' * single output, single `torch.optim.Optimizer`'
+                ' * single output, list of `torch.optim.Optimizer`'
+                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
+                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
+                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
+                ' a list of `torch.optim.lr_scheduler`'
+                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')

     def configure_schedulers(self, schedulers: list):
         # Convert each scheduler into dict structure with relevant information
@@ -966,6 +996,7 @@ class _PatchDataLoader(object):
         dataloader: Dataloader object to return when called.

     """

     def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
         self.dataloader = dataloader

@@ -197,6 +197,7 @@ class TrainerTrainLoopMixin(ABC):
     total_batches: int
     truncated_bptt_steps: ...
     optimizers: ...
+    optimizer_frequencies: ...
     accumulate_grad_batches: int
     use_amp: bool
     track_grad_norm: ...
@@ -530,8 +531,7 @@ class TrainerTrainLoopMixin(ABC):
         for split_idx, split_batch in enumerate(splits):
             self.split_idx = split_idx

-            # call training_step once per optimizer
-            for opt_idx, optimizer in enumerate(self.optimizers):
+            for opt_idx, optimizer in self._get_optimizers_iterable():
                 # make sure only the gradients of the current optimizer's parameters are calculated
                 # in the training step to prevent dangling gradients in multiple-optimizer setup.
                 if len(self.optimizers) > 1:
@@ -634,6 +634,19 @@ class TrainerTrainLoopMixin(ABC):

         return 0, grad_norm_dic, all_log_metrics

+    def _get_optimizers_iterable(self):
+        if not self.optimizer_frequencies:
+            # call training_step once per optimizer
+            return list(enumerate(self.optimizers))
+
+        optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
+        optimizers_loop_length = optimizer_freq_cumsum[-1]
+        current_place_in_loop = self.total_batch_idx % optimizers_loop_length
+
+        # find optimizer index by looking for the first {item > current_place} in the cumsum list
+        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
+        return [(opt_idx, self.optimizers[opt_idx])]
+
     def run_training_teardown(self):
         self.main_progress_bar.close()

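As a quick sanity check of the cumulative-sum selection in `_get_optimizers_iterable` above, the following standalone sketch (frequencies chosen to match the WGAN-style docstring example; not part of the commit) prints which optimizer index is picked for each batch:

    import numpy as np

    optimizer_frequencies = [5, 1]                   # discriminator: 5 batches, generator: 1 batch
    freq_cumsum = np.cumsum(optimizer_frequencies)   # array([5, 6])
    loop_length = freq_cumsum[-1]                    # 6

    for total_batch_idx in range(12):
        current_place_in_loop = total_batch_idx % loop_length
        # first index whose cumulative frequency exceeds the current position in the cycle
        opt_idx = int(np.argmax(freq_cumsum > current_place_in_loop))
        print(total_batch_idx, opt_idx)
    # batches 0-4 -> optimizer 0, batch 5 -> optimizer 1, then the cycle repeats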
@@ -82,7 +82,8 @@ def run_model_test(trainer_options, model, on_gpu=True):
     if trainer.use_ddp or trainer.use_ddp2:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())
+        trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
+            trainer.init_optimizers(pretrained_model.configure_optimizers())

     # test HPC loading / saving
     trainer.hpc_save(save_dir, logger)
@@ -93,30 +93,52 @@ def test_optimizer_return_options():

     # single optimizer
     opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
     opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
-    optim, lr_sched = trainer.init_optimizers(opt_a)
-    assert len(optim) == 1 and len(lr_sched) == 0
+    scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
+    scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)
+
+    # single optimizer
+    optim, lr_sched, freq = trainer.init_optimizers(opt_a)
+    assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

     # opt tuple
     opts = (opt_a, opt_b)
-    optim, lr_sched = trainer.init_optimizers(opts)
+    optim, lr_sched, freq = trainer.init_optimizers(opts)
     assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
-    assert len(lr_sched) == 0
+    assert len(lr_sched) == 0 and len(freq) == 0

     # opt list
     opts = [opt_a, opt_b]
-    optim, lr_sched = trainer.init_optimizers(opts)
+    optim, lr_sched, freq = trainer.init_optimizers(opts)
     assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
-    assert len(lr_sched) == 0
+    assert len(lr_sched) == 0 and len(freq) == 0

-    # opt tuple of lists
-    scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
-    opts = ([opt_a], [scheduler])
-    optim, lr_sched = trainer.init_optimizers(opts)
-    assert len(optim) == 1 and len(lr_sched) == 1
-    assert optim[0] == opts[0][0] and \
-        lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
-                            frequency=1, reduce_on_plateau=False,
-                            monitor='val_loss')
+    # opt tuple of 2 lists
+    opts = ([opt_a], [scheduler_a])
+    optim, lr_sched, freq = trainer.init_optimizers(opts)
+    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
+    assert optim[0] == opts[0][0]
+    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
+                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+
+    # opt single dictionary
+    opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a}
+    optim, lr_sched, freq = trainer.init_optimizers(opts)
+    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
+    assert optim[0] == opt_a
+    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
+                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+
+    # opt multiple dictionaries with frequencies
+    opts = (
+        {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
+        {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
+    )
+    optim, lr_sched, freq = trainer.init_optimizers(opts)
+    assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
+    assert optim[0] == opt_a
+    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
+                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+    assert freq == [1, 5]


 def test_cpu_slurm_save_load(tmpdir):
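One natural follow-up to the test above, sketched here rather than taken from the commit, is the error path of `init_optimizers`: a `frequency` given to only some of the optimizers should trigger the `ValueError` added in this commit. The snippet assumes `Trainer()` can be constructed with defaults, as in the surrounding tests; the parameter list is a placeholder.

    import pytest
    import torch
    from pytorch_lightning import Trainer

    def test_optimizer_frequency_given_to_all_or_none():
        # Sketch only: frequencies must be given to all optimizers or to none of them.
        params = [torch.nn.Parameter(torch.zeros(1))]
        opt_a = torch.optim.Adam(params, lr=0.002)
        opt_b = torch.optim.SGD(params, lr=0.002)
        opts = (
            {'optimizer': opt_a, 'frequency': 1},
            {'optimizer': opt_b},  # no 'frequency' key -> should be rejected
        )
        with pytest.raises(ValueError, match='frequency'):
            Trainer().init_optimizers(opts)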