Optimizer Frequencies logic, and new configure_optimizers (#1269)

* init_optimizers accepts Dict, Sequence[Dict]
and returns optimizer_frequencies.
optimizer_frequencies was added as a member of Trainer.

* Optimizer frequencies logic implemented in training_loop.
Description added to configure_optimizers in LightningModule

* optimizer frequencies tests added to test_gpu

* Fixed formatting for merging PR #1269

* Apply suggestions from code review

* Apply suggestions from code review

Co-Authored-By: Asaf Manor <32155911+asafmanor@users.noreply.github.com>

* Update trainer.py

* Moving get_optimizers_iterable() outside.

* Update note

* Apply suggestions from code review

* formatting

* formatting

* Update CHANGELOG.md

* formatting

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Asaf Manor 2020-03-31 19:41:24 +03:00 committed by GitHub
parent ee68d5ba8e
commit aca8c7e6f3
8 changed files with 137 additions and 43 deletions
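
For orientation before the per-file diff: a minimal, hypothetical sketch of the frequency-based configuration described in the commit message above (the module name, layer sizes, and attribute names are illustrative and not taken from the commit):

import torch
import pytorch_lightning as pl

class FreqGAN(pl.LightningModule):  # illustrative module name
    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(8, 8)
        self.discriminator = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=0.02)
        gen_opt = torch.optim.Adam(self.generator.parameters(), lr=0.01)
        # the discriminator optimizer is used for 5 consecutive batches,
        # then the generator optimizer for 1 batch, and the cycle repeats
        return (
            {'optimizer': dis_opt, 'frequency': 5},
            {'optimizer': gen_opt, 'frequency': 1},
        )

The diff below adds this return format to `init_optimizers()` and consumes the resulting frequencies in the training loop.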

@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))

@ -914,10 +914,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
If you don't define this method Lightning will automatically use Adam(lr=1e-3)
Return: any of these 3 options:
- Single optimizer
- List or Tuple - List of optimizers
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers
Return: any of these 5 options:
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
- Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
- Tuple of dictionaries as described, with an optional `frequency` key.
Note:
The `frequency` value is an int corresponding to the number of sequential batches
optimized with the specific optimizer. It should be given to none or to all of the optimizers.
There is a difference between passing multiple optimizers in a list,
and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
Examples:
.. code-block:: python
@ -949,6 +959,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]
# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)
Note:
Some things to know:

@ -304,7 +304,8 @@ class TrainerDDPMixin(ABC):
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
# MODEL
# copy model to each gpu

@ -459,7 +459,8 @@ class TrainerDPMixin(ABC):
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
if self.use_amp:
# An example
@ -485,7 +486,8 @@ class TrainerDPMixin(ABC):
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
# init 16 bit for TPU
if self.precision == 16:
@ -503,7 +505,8 @@ class TrainerDPMixin(ABC):
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
model.cuda(self.root_gpu)

@ -3,7 +3,7 @@ import os
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
import distutils
import torch
@ -354,6 +354,7 @@ class Trainer(
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0
@ -710,7 +711,8 @@ class Trainer(
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.run_pretrain_routine(model)
@ -756,31 +758,59 @@ class Trainer(
def init_optimizers(
self,
optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:
optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
) -> Tuple[List, List, List]:
# single output, single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []
# two lists, optimizer + lr schedulers
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers
return optimizers, lr_schedulers, []
# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take the lr_scheduler only if the key exists and its value is defined (not None)
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")]
# take the frequency only if the key exists and its value is defined (not None)
optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")]
# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies
# single list or tuple, multiple optimizer
elif isinstance(optimizers, (list, tuple)):
return optimizers, []
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []
# unknown configuration
else:
raise ValueError('Unknown configuration for model optimizers. Output'
'from model.configure_optimizers() should either be:'
'* single output, single torch.optim.Optimizer'
'* single output, list of torch.optim.Optimizer'
'* two outputs, first being a list of torch.optim.Optimizer',
'second being a list of torch.optim.lr_scheduler')
raise ValueError(
'Unknown configuration for model optimizers.'
' Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being'
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
def configure_schedulers(self, schedulers: list):
# Convert each scheduler into a dict structure with relevant information
@ -966,6 +996,7 @@ class _PatchDataLoader(object):
dataloader: Dataloader object to return when called.
"""
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader
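
To make the new three-element return value of `init_optimizers()` concrete, a short usage sketch follows; the model and optimizers are made up for illustration, and the behaviour mirrors the unit tests further below:

import torch
from pytorch_lightning import Trainer

model = torch.nn.Linear(4, 2)
opt_a = torch.optim.Adam(model.parameters(), lr=0.01)
opt_b = torch.optim.SGD(model.parameters(), lr=0.02)

optim_conf = (
    {'optimizer': opt_a, 'frequency': 5},
    {'optimizer': opt_b, 'frequency': 1},
)

trainer = Trainer()
optimizers, lr_schedulers, frequencies = trainer.init_optimizers(optim_conf)
# optimizers    -> [opt_a, opt_b]
# lr_schedulers -> []        (no 'lr_scheduler' keys were provided)
# frequencies   -> [5, 1]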

@ -197,6 +197,7 @@ class TrainerTrainLoopMixin(ABC):
total_batches: int
truncated_bptt_steps: ...
optimizers: ...
optimizer_frequencies: ...
accumulate_grad_batches: int
use_amp: bool
track_grad_norm: ...
@ -530,8 +531,7 @@ class TrainerTrainLoopMixin(ABC):
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx
# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
for opt_idx, optimizer in self._get_optimizers_iterable():
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if len(self.optimizers) > 1:
@ -634,6 +634,19 @@ class TrainerTrainLoopMixin(ABC):
return 0, grad_norm_dic, all_log_metrics
def _get_optimizers_iterable(self):
if not self.optimizer_frequencies:
# call training_step once per optimizer
return list(enumerate(self.optimizers))
optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
optimizers_loop_length = optimizer_freq_cumsum[-1]
current_place_in_loop = self.total_batch_idx % optimizers_loop_length
# find the optimizer index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
return [(opt_idx, self.optimizers[opt_idx])]
def run_training_teardown(self):
self.main_progress_bar.close()
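
The selection logic of `_get_optimizers_iterable()` can be illustrated in isolation. The sketch below reproduces the cumulative-sum lookup with made-up frequencies and no Trainer involved:

import numpy as np

frequencies = [5, 1]                  # e.g. n_critic = 5 for a discriminator, 1 for a generator
freq_cumsum = np.cumsum(frequencies)  # array([5, 6])
loop_length = freq_cumsum[-1]         # one full cycle spans 6 batches

for batch_idx in range(12):
    place_in_loop = batch_idx % loop_length
    # index of the first cumulative frequency greater than the current position
    opt_idx = int(np.argmax(freq_cumsum > place_in_loop))
    print(batch_idx, opt_idx)
# batches 0-4 select optimizer 0, batch 5 selects optimizer 1, then the cycle repeats

When `optimizer_frequencies` is empty, the loop keeps the previous behaviour and steps every optimizer on every batch.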

@ -82,7 +82,8 @@ def run_model_test(trainer_options, model, on_gpu=True):
if trainer.use_ddp or trainer.use_ddp2:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
trainer.init_optimizers(pretrained_model.configure_optimizers())
# test HPC loading / saving
trainer.hpc_save(save_dir, logger)

@ -93,30 +93,52 @@ def test_optimizer_return_options():
# single optimizer
opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
optim, lr_sched = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0
scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)
# single optimizer
optim, lr_sched, freq = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0
# opt tuple
opts = (opt_a, opt_b)
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0
assert len(lr_sched) == 0 and len(freq) == 0
# opt list
opts = [opt_a, opt_b]
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0
assert len(lr_sched) == 0 and len(freq) == 0
# opt tuple of lists
scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
opts = ([opt_a], [scheduler])
optim, lr_sched = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1
assert optim[0] == opts[0][0] and \
lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')
# opt tuple of 2 lists
opts = ([opt_a], [scheduler_a])
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opts[0][0]
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')
# opt single dictionary
opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a}
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')
# opt multiple dictionaries with frequencies
opts = (
{"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
{"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')
assert freq == [1, 5]
def test_cpu_slurm_save_load(tmpdir):