From 7d97e3e6e47d63a8e219e06ecea8c61c53a23afc Mon Sep 17 00:00:00 2001 From: Phuc Le Date: Wed, 24 Jul 2019 12:12:45 +0700 Subject: [PATCH 1/5] Support any lr_scheduler --- README.md | 1 - .../RequiredTrainerInterface.md | 21 ++++++++++--------- docs/Trainer/Training Loop.md | 11 ---------- docs/Trainer/index.md | 1 - docs/index.md | 1 - .../lightning_module_template.py | 4 ++-- .../sample_model_template/model_template.py | 5 +++-- pytorch_lightning/models/trainer.py | 15 +++---------- pytorch_lightning/root_module/root_module.py | 2 +- 9 files changed, 20 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 85b27c204d..fc975ab54a 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,6 @@ tensorboard --logdir /some/path ###### Training loop - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients) -- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate) - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs) - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop) - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping) diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index 96522c7eec..f5e36e8f37 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -222,26 +222,27 @@ def validation_end(self, outputs): def configure_optimizers(self) ``` -Set up as many optimizers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. -Lightning will call .backward() and .step() on each one. If you use 16 bit precision it will also handle that. +Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. +Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that. 
 ##### Return
 
-List - List of optimizers
+Tuple - List of optimizers and list of schedulers
 
 **Example**
 
 ``` {.python}
 # most cases
 def configure_optimizers(self):
-    opt = Adam(lr=0.01)
-    return [opt]
+    opt = Adam(self.model.parameters(), lr=0.01)
+    return [opt], []
 
-# gan example
+# gan example, with scheduler for discriminator
 def configure_optimizers(self):
-    generator_opt = Adam(lr=0.01)
-    disriminator_opt = Adam(lr=0.02)
-    return [generator_opt, disriminator_opt]
+    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
+    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
+    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
+    return [generator_opt, discriminator_opt], [discriminator_sched]
 ```
 
 ---
@@ -427,4 +428,4 @@ def add_model_specific_args(parent_parser, root_dir):
     parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
     parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
     return parser
-```
\ No newline at end of file
+```
diff --git a/docs/Trainer/Training Loop.md b/docs/Trainer/Training Loop.md
index 2be8da9edd..7a8a7c6058 100644
--- a/docs/Trainer/Training Loop.md
+++ b/docs/Trainer/Training Loop.md
@@ -11,17 +11,6 @@ Accumulated gradients runs K small batches of size N before doing a backwards pa
 trainer = Trainer(accumulate_grad_batches=1)
 ```
 
----
-#### Anneal Learning rate
-Cut the learning rate by 10 at every epoch listed in this list.
-``` {.python}
-# DEFAULT (don't anneal)
-trainer = Trainer(lr_scheduler_milestones=None)
-
-# cut LR by 10 at 100, 200, and 300 epochs
-trainer = Trainer(lr_scheduler_milestones='100, 200, 300')
-```
-
 ---
 #### Force training for min or max epochs
 It can be useful to force training for a minimum number of epochs or limit to a max number
diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md
index 1b30da1966..d670ec8b32 100644
--- a/docs/Trainer/index.md
+++ b/docs/Trainer/index.md
@@ -59,7 +59,6 @@ But of course the fun is in all the advanced things it can do:
 
 **Training loop**
 - [Accumulate gradients](Training%20Loop/#accumulated-gradients)
-- [Anneal Learning rate](Training%20Loop/#anneal-learning-rate)
 - [Force training for min or max epochs](Training%20Loop/#force-training-for-min-or-max-epochs)
 - [Force disable early stop](Training%20Loop/#force-disable-early-stop)
 - [Use multiple optimizers (like GANs)](../Pytorch-lightning/LightningModule/#configure_optimizers)
diff --git a/docs/index.md b/docs/index.md
index 0e25fa79d5..91d344685b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -60,7 +60,6 @@ To start a new project define these two files.
###### Training loop - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients) -- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate) - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs) - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop) - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping) diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index 0a4dab2692..6e48bb3650 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -174,7 +174,8 @@ class LightningTemplateModel(LightningModule): :return: list of optimizers """ optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return [optimizer] + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + return [optimizer], [scheduler] def __dataloader(self, train): # init data generators @@ -231,7 +232,6 @@ class LightningTemplateModel(LightningModule): # parser.set_defaults(gradient_clip=5.0) # network params - parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) parser.add_argument('--in_features', default=28*28, type=int) parser.add_argument('--out_features', default=10, type=int) parser.add_argument('--hidden_dim', default=50000, type=int) # use 500 for CPU, 50000 for GPU to see speed difference diff --git a/pytorch_lightning/models/sample_model_template/model_template.py b/pytorch_lightning/models/sample_model_template/model_template.py index 10f12c59a1..0c446a11ca 100644 --- a/pytorch_lightning/models/sample_model_template/model_template.py +++ b/pytorch_lightning/models/sample_model_template/model_template.py @@ -128,12 +128,13 @@ class ExampleModel1(LightningModule): # --------------------- def configure_optimizers(self): """ - return whatever optimizers we want here + return whatever optimizers and (optionally) schedulers we want here :return: list of optimizers """ optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer') self.optimizers = [optimizer] - return self.optimizers + self.schedulers = [] + return self.optimizers, self.schedulers def __dataloader(self, train): # init data generators diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index bbfafd8933..5da1ead938 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -71,7 +71,6 @@ class Trainer(TrainerIO): train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95, log_save_interval=100, add_log_row_interval=10, - lr_scheduler_milestones=None, distributed_backend='dp', use_amp=False, print_nan_grads=False, @@ -104,7 +103,6 @@ class Trainer(TrainerIO): :param val_check_interval: :param log_save_interval: :param add_log_row_interval: - :param lr_scheduler_milestones: :param distributed_backend: 'np' to use DistributedParallel, 'ddp' to use DistributedDataParallel :param use_amp: :param print_nan_grads: @@ -141,7 +139,6 
@@ class Trainer(TrainerIO): self.early_stop_callback = early_stop_callback self.min_nb_epochs = min_nb_epochs self.nb_sanity_val_steps = nb_sanity_val_steps - self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')] self.lr_schedulers = [] self.amp_level = amp_level self.print_nan_grads = print_nan_grads @@ -444,7 +441,7 @@ class Trainer(TrainerIO): # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers = model.configure_optimizers() + self.optimizers, self.lr_schedulers = model.configure_optimizers() self.__run_pretrain_routine(model) @@ -456,7 +453,7 @@ class Trainer(TrainerIO): # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers = model.configure_optimizers() + self.optimizers, self.lr_schedulers = model.configure_optimizers() model.cuda(self.data_parallel_device_ids[0]) @@ -507,7 +504,7 @@ class Trainer(TrainerIO): # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers = model.configure_optimizers() + self.optimizers, self.lr_schedulers = model.configure_optimizers() # MODEL # copy model to each gpu @@ -587,12 +584,6 @@ class Trainer(TrainerIO): # init training constants self.__layout_bookeeping() - # add lr schedulers - if self.lr_scheduler_milestones is not None: - for optimizer in self.optimizers: - scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones) - self.lr_schedulers.append(scheduler) - # print model summary if self.proc_rank == 0 and self.print_weights_summary: ref_model.summarize() diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index d6d740f039..415934b7da 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -58,7 +58,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): def configure_optimizers(self): """ - Return array of optimizers + Return a list of optimizers and a list of schedulers (could be empty) :return: """ raise NotImplementedError From ba25161dcc1be898ba3443a3c3c8d7d5f77b233c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 26 Jul 2019 08:23:37 -0400 Subject: [PATCH 2/5] Update trainer.py --- pytorch_lightning/models/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 5da1ead938..c21cf3be01 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -10,7 +10,6 @@ import re import torch from torch.utils.data.distributed import DistributedSampler -from torch.optim.lr_scheduler import MultiStepLR import torch.multiprocessing as mp import torch.distributed as dist import numpy as np From 5ba0a8fe4cb5fa08913bdab90db496cc11ea62e4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 26 Jul 2019 08:24:56 -0400 Subject: [PATCH 3/5] Update RequiredTrainerInterface.md --- docs/LightningModule/RequiredTrainerInterface.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index f5e36e8f37..ff9b84916e 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -234,7 +234,7 @@ Tuple - List of optimizers and list of schedulers ``` {.python} # most cases def 
configure_optimizers(self): - opt = Adam(self.model.parameters(), lr=0.01) + opt = Adam(self.parameters(), lr=0.01) return [opt], [] # gan example, with scheduler for discriminator From 98fcc1713542c64ff3ac6a7a1dc2b673f35bb1f9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 26 Jul 2019 08:27:14 -0400 Subject: [PATCH 4/5] Update trainer.py --- pytorch_lightning/models/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index c21cf3be01..9c7545103d 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -612,8 +612,9 @@ class Trainer(TrainerIO): # run all epochs for epoch_nb in range(self.current_epoch, self.max_nb_epochs): # update the lr scheduler - for lr_scheduler in self.lr_schedulers: - lr_scheduler.step() + if self.lr_schedulers is not None: + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step() model = self.__get_model() model.current_epoch = epoch_nb From c7dab0d7856f7035d5c4bdd0b37e550d49428599 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 26 Jul 2019 14:39:04 -0400 Subject: [PATCH 5/5] Update lm_test_module.py --- pytorch_lightning/testing_models/lm_test_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index e33ee53e33..158f8fbf0c 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -190,8 +190,9 @@ class LightningTestModel(LightningModule): return whatever optimizers we want here :return: list of optimizers """ + # try no scheduler for this model (testing purposes) optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return [optimizer] + return [optimizer], [] def __dataloader(self, train): # init data generators
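
Taken together, this series changes the `configure_optimizers` contract from returning a list of optimizers to returning a tuple of (list of optimizers, list of schedulers). For reference, a minimal runnable sketch of the module side of that contract, using only `torch` APIs; the `Net` class and its single linear layer are hypothetical stand-ins for a user's LightningModule, not code from this repository:

``` {.python}
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR


class Net(nn.Module):
    """Hypothetical user model; only configure_optimizers matters for the new contract."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        # new return shape: ([optimizers], [schedulers]); the scheduler list may be empty
        optimizer = optim.Adam(self.parameters(), lr=0.01)
        scheduler = CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]


optimizers, lr_schedulers = Net().configure_optimizers()
```

Returning `[optimizer], []` keeps the previous single-optimizer behaviour, which is exactly what PATCH 5/5 does for the test model.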
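
On the trainer side, PATCH 1/5 and 2/5 drop the old `lr_scheduler_milestones`/`MultiStepLR` wiring and instead step whatever schedulers the module returned once per epoch, with PATCH 4/5 guarding against a missing scheduler list. Below is a simplified, self-contained sketch of that per-epoch stepping, assuming a toy model and a `StepLR` schedule chosen only for illustration; the stripped-down loop is not the real `trainer.py` code:

``` {.python}
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# stand-ins for what model.configure_optimizers() now hands back to the trainer
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizers, lr_schedulers = [optimizer], [StepLR(optimizer, step_size=1, gamma=0.5)]

max_nb_epochs = 3
for epoch_nb in range(max_nb_epochs):
    # update the lr schedulers (if any) once at the start of every epoch
    if lr_schedulers is not None:
        for lr_scheduler in lr_schedulers:
            lr_scheduler.step()

    # ... per-batch work runs here: zero_grad(), loss.backward(),
    # and optimizer.step() for each optimizer ...
    print(epoch_nb, optimizers[0].param_groups[0]['lr'])
```

The schedulers are stepped at the top of each epoch, before the batches run, which mirrors where the updated loop in `trainer.py` places the call.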