diff --git a/CHANGELOG.md b/CHANGELOG.md index 02110f6e68..4015fc1829 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211)) +- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269)) - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) - Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259)) - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8cec41da2b..99b01865d9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -914,10 +914,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): If you don't define this method Lightning will automatically use Adam(lr=1e-3) - Return: any of these 3 options: - - Single optimizer - - List or Tuple - List of optimizers - - Two lists - The first list has multiple optimizers, the second a list of LR schedulers + Return: any of these 5 options: + - Single optimizer. + - List or Tuple - List of optimizers. + - Two lists - The first list has multiple optimizers, the second a list of LR schedulers. + - Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key. + - Tuple of dictionaries as described, with an optional `frequency` key. + + Note: + The `frequency` value is an int corresponding to the number of sequential batches + optimized with the specific optimizer. It should be given either to none of the optimizers or to all of them. + There is a difference between passing multiple optimizers in a list + and passing multiple optimizers in dictionaries with a frequency of 1: + In the former case, all optimizers will operate on the given batch in each optimization step. + In the latter, only one optimizer will operate on the given batch at every step. Examples: ..
code-block:: python @@ -949,6 +959,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch return [gen_opt, dis_opt], [gen_sched, dis_sched] + # example with optimizer frequencies + # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 + # https://arxiv.org/abs/1704.00028 + def configure_optimizers(self): + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_disc.parameters(), lr=0.02) + n_critic = 5 + return ( + {'optimizer': dis_opt, 'frequency': n_critic}, + {'optimizer': gen_opt, 'frequency': 1} + ) + Note: Some things to know: diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 95b9a61974..e592308a9e 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -304,7 +304,8 @@ class TrainerDDPMixin(ABC): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ + self.init_optimizers(model.configure_optimizers()) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index fc6007b75d..57e460feb6 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -459,7 +459,8 @@ class TrainerDPMixin(ABC): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ + self.init_optimizers(model.configure_optimizers()) if self.use_amp: # An example @@ -485,7 +486,8 @@ class TrainerDPMixin(ABC): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ + self.init_optimizers(model.configure_optimizers()) # init 16 bit for TPU if self.precision == 16: @@ -503,7 +505,8 @@ class TrainerDPMixin(ABC): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ + self.init_optimizers(model.configure_optimizers()) model.cuda(self.root_gpu) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 69b0376ff1..37ecbe67b4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -3,7 +3,7 @@ import os import sys import warnings from argparse import ArgumentParser -from typing import Union, Optional, List, Dict, Tuple, Iterable, Any +from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence import distutils import torch @@ -354,6 +354,7 @@ class Trainer( self.disable_validation = False self.lr_schedulers = [] self.optimizers = None + self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 self.total_batches = 0 @@ -710,7 +711,8 @@ class Trainer( # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ + 
self.init_optimizers(model.configure_optimizers()) self.run_pretrain_routine(model) @@ -756,31 +758,59 @@ class Trainer( def init_optimizers( self, - optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]] - ) -> Tuple[List, List]: + optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]] + ) -> Tuple[List, List, List]: # single output, single optimizer - if isinstance(optimizers, Optimizer): - return [optimizers], [] + if isinstance(optim_conf, Optimizer): + return [optim_conf], [], [] # two lists, optimizer + lr schedulers - elif len(optimizers) == 2 and isinstance(optimizers[0], list): - optimizers, lr_schedulers = optimizers + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): + optimizers, lr_schedulers = optim_conf lr_schedulers = self.configure_schedulers(lr_schedulers) - return optimizers, lr_schedulers + return optimizers, lr_schedulers, [] + + # single dictionary + elif isinstance(optim_conf, dict): + optimizer = optim_conf["optimizer"] + lr_scheduler = optim_conf.get("lr_scheduler", []) + lr_schedulers = [] + if lr_scheduler: + lr_schedulers = self.configure_schedulers([lr_scheduler]) + return [optimizer], lr_schedulers, [] + + # multiple dictionaries + elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + # take the lr_scheduler only if it exists and is defined (not None) + lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")] + # take the frequency only if it exists and is defined (not None) + optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")] + + # clean scheduler list + if lr_schedulers: + lr_schedulers = self.configure_schedulers(lr_schedulers) + # assert that if frequencies are present, they are given for all optimizers + if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): + raise ValueError("A frequency must be given to each optimizer.") + return optimizers, lr_schedulers, optimizer_frequencies # single list or tuple, multiple optimizer - elif isinstance(optimizers, (list, tuple)): - return optimizers, [] + elif isinstance(optim_conf, (list, tuple)): + return list(optim_conf), [], [] # unknown configuration else: - raise ValueError('Unknown configuration for model optimizers. Output' 'from model.configure_optimizers() should either be:' '* single output, single torch.optim.Optimizer' '* single output, list of torch.optim.Optimizer' '* two outputs, first being a list of torch.optim.Optimizer', 'second being a list of torch.optim.lr_scheduler') + raise ValueError( 'Unknown configuration for model optimizers.' ' Output from `model.configure_optimizers()` should either be:' ' * single output, single `torch.optim.Optimizer`' ' * single output, list of `torch.optim.Optimizer`' ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' ' * two outputs, first being a list of `torch.optim.Optimizer` second being' ' a list of `torch.optim.lr_scheduler`' ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') def configure_schedulers(self, schedulers: list): # Convert each scheduler into dict sturcture with relevant information @@ -966,6 +996,7 @@ class _PatchDataLoader(object): dataloader: Dataloader object to return when called.
""" + def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 02abae9b47..2a245ff5fc 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -197,6 +197,7 @@ class TrainerTrainLoopMixin(ABC): total_batches: int truncated_bptt_steps: ... optimizers: ... + optimizer_frequencies: ... accumulate_grad_batches: int use_amp: bool track_grad_norm: ... @@ -530,8 +531,7 @@ class TrainerTrainLoopMixin(ABC): for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx - # call training_step once per optimizer - for opt_idx, optimizer in enumerate(self.optimizers): + for opt_idx, optimizer in self._get_optimizers_iterable(): # make sure only the gradients of the current optimizer's paramaters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if len(self.optimizers) > 1: @@ -634,6 +634,19 @@ class TrainerTrainLoopMixin(ABC): return 0, grad_norm_dic, all_log_metrics + def _get_optimizers_iterable(self): + if not self.optimizer_frequencies: + # call training_step once per optimizer + return list(enumerate(self.optimizers)) + + optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies) + optimizers_loop_length = optimizer_freq_cumsum[-1] + current_place_in_loop = self.total_batch_idx % optimizers_loop_length + + # find optimzier index by looking for the first {item > current_place} in the cumsum list + opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) + return [(opt_idx, self.optimizers[opt_idx])] + def run_training_teardown(self): self.main_progress_bar.close() diff --git a/tests/base/utils.py b/tests/base/utils.py index ee206d67e5..f1d26f0b83 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -82,7 +82,8 @@ def run_model_test(trainer_options, model, on_gpu=True): if trainer.use_ddp or trainer.use_ddp2: # on hpc this would work fine... 
but need to hack it for the purpose of the test trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers()) + trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ + trainer.init_optimizers(pretrained_model.configure_optimizers()) # test HPC loading / saving trainer.hpc_save(save_dir, logger) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 5de29f647d..f783a2d32a 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -93,30 +93,52 @@ def test_optimizer_return_options(): # single optimizer opt_a = torch.optim.Adam(model.parameters(), lr=0.002) opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - optim, lr_sched = trainer.init_optimizers(opt_a) - assert len(optim) == 1 and len(lr_sched) == 0 + scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) + scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) + + # single optimizer + optim, lr_sched, freq = trainer.init_optimizers(opt_a) + assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0 # opt tuple opts = (opt_a, opt_b) - optim, lr_sched = trainer.init_optimizers(opts) + optim, lr_sched, freq = trainer.init_optimizers(opts) assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] - assert len(lr_sched) == 0 + assert len(lr_sched) == 0 and len(freq) == 0 # opt list opts = [opt_a, opt_b] - optim, lr_sched = trainer.init_optimizers(opts) + optim, lr_sched, freq = trainer.init_optimizers(opts) assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] - assert len(lr_sched) == 0 + assert len(lr_sched) == 0 and len(freq) == 0 - # opt tuple of lists - scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10) - opts = ([opt_a], [scheduler]) - optim, lr_sched = trainer.init_optimizers(opts) - assert len(optim) == 1 and len(lr_sched) == 1 - assert optim[0] == opts[0][0] and \ - lr_sched[0] == dict(scheduler=scheduler, interval='epoch', - frequency=1, reduce_on_plateau=False, - monitor='val_loss') + # opt tuple of 2 lists + opts = ([opt_a], [scheduler_a]) + optim, lr_sched, freq = trainer.init_optimizers(opts) + assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert optim[0] == opts[0][0] + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + + # opt single dictionary + opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a} + optim, lr_sched, freq = trainer.init_optimizers(opts) + assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + + # opt multiple dictionaries with frequencies + opts = ( + {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, + {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, + ) + optim, lr_sched, freq = trainer.init_optimizers(opts) + assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2 + assert optim[0] == opt_a + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + assert freq == [1, 5] def test_cpu_slurm_save_load(tmpdir):
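For reference, here is a minimal standalone sketch of the frequency-based optimizer selection introduced in `_get_optimizers_iterable` above. It is not the Trainer code itself: the helper name `pick_optimizer_index` is purely illustrative, and only NumPy is assumed. It reproduces the cumulative-sum lookup and prints the schedule produced by the WGAN-style docstring example (`n_critic = 5`).

.. code-block:: python

    import numpy as np

    def pick_optimizer_index(batch_idx, frequencies):
        """Return the index of the optimizer that should run on this batch."""
        freq_cumsum = np.cumsum(frequencies)      # e.g. [5, 1] -> [5, 6]
        loop_length = freq_cumsum[-1]             # one full cycle = sum of all frequencies
        place_in_loop = batch_idx % loop_length   # position inside the current cycle
        # first cumulative entry strictly greater than the current position
        return int(np.argmax(freq_cumsum > place_in_loop))

    # schedule for the WGAN-style example: 5 discriminator batches, then 1 generator batch
    schedule = [pick_optimizer_index(i, [5, 1]) for i in range(12)]
    print(schedule)  # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]

With a frequency of 1 for every optimizer this degenerates to a strict alternation (one optimizer per batch), which is exactly the difference from returning a plain list of optimizers, where every optimizer steps on every batch, as stated in the docstring note above.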
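The "multiple dictionaries" branch of `init_optimizers` can also be summarized in isolation. The sketch below is a simplified stand-in, not the Trainer method: it skips the `configure_schedulers` conversion, uses strings in place of real `torch.optim` objects, and `parse_optim_dicts` is a hypothetical name. It shows how the dictionary format is normalized and the rule that a `frequency` must be given either to none of the optimizers or to all of them.

.. code-block:: python

    from typing import Dict, List, Sequence, Tuple

    def parse_optim_dicts(optim_conf: Sequence[Dict]) -> Tuple[List, List, List]:
        optimizers = [d["optimizer"] for d in optim_conf]
        # keep only schedulers / frequencies that are present and not None
        lr_schedulers = [d["lr_scheduler"] for d in optim_conf if d.get("lr_scheduler")]
        frequencies = [d["frequency"] for d in optim_conf if d.get("frequency")]
        # frequencies are all-or-nothing across the optimizers
        if frequencies and len(frequencies) != len(optimizers):
            raise ValueError("A frequency must be given to each optimizer.")
        return optimizers, lr_schedulers, frequencies

    # mirrors the dict-based return in the docstring example (strings stand in for optimizers)
    print(parse_optim_dicts((
        {"optimizer": "dis_opt", "frequency": 5},
        {"optimizer": "gen_opt", "frequency": 1},
    )))  # (['dis_opt', 'gen_opt'], [], [5, 1])

    # giving a frequency to only one of the optimizers raises ValueError:
    # parse_optim_dicts(({"optimizer": "dis_opt", "frequency": 5}, {"optimizer": "gen_opt"}))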