diff --git a/README.md b/README.md
index afbefb16e7..549a48504d 100644
--- a/README.md
+++ b/README.md
@@ -276,7 +276,6 @@ tensorboard --logdir /some/path
 ###### Training loop
 
 - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients)
-- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate)
 - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs)
 - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop)
 - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping)
diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md
index 059c54260f..ef7c2e3008 100644
--- a/docs/LightningModule/RequiredTrainerInterface.md
+++ b/docs/LightningModule/RequiredTrainerInterface.md
@@ -225,26 +225,27 @@ def validation_end(self, outputs):
 def configure_optimizers(self)
 ```
 
-Set up as many optimizers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
-Lightning will call .backward() and .step() on each one. If you use 16 bit precision it will also handle that.
+Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
+Lightning will call .backward() and .step() on each optimizer, and step each scheduler once per epoch. If you use 16-bit precision it will also handle that.
 
 ##### Return
 
-List - List of optimizers
+Tuple - List of optimizers and list of schedulers
 
 **Example**
 
 ``` {.python}
 # most cases
 def configure_optimizers(self):
-    opt = Adam(lr=0.01)
-    return [opt]
+    opt = Adam(self.parameters(), lr=0.01)
+    return [opt], []
 
-# gan example
+# gan example, with scheduler for discriminator
 def configure_optimizers(self):
-    generator_opt = Adam(lr=0.01)
-    disriminator_opt = Adam(lr=0.02)
-    return [generator_opt, disriminator_opt]
+    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
+    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
+    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
+    return [generator_opt, discriminator_opt], [discriminator_sched]
 ```
 
 ---
@@ -431,4 +432,4 @@ def add_model_specific_args(parent_parser, root_dir):
     parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
     parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
     return parser
-```
\ No newline at end of file
+```
diff --git a/docs/Trainer/Training Loop.md b/docs/Trainer/Training Loop.md
index e1ff90a484..c2b6dc35a0 100644
--- a/docs/Trainer/Training Loop.md
+++ b/docs/Trainer/Training Loop.md
@@ -11,17 +11,6 @@ Accumulated gradients runs K small batches of size N before doing a backwards pa
 trainer = Trainer(accumulate_grad_batches=1)
 ```
 
----
-#### Anneal Learning rate
-Cut the learning rate by 10 at every epoch listed in this list.
-``` {.python}
-# DEFAULT (don't anneal)
-trainer = Trainer(lr_scheduler_milestones=None)
-
-# cut LR by 10 at 100, 200, and 300 epochs
-trainer = Trainer(lr_scheduler_milestones='100, 200, 300')
-```
-
 ---
 #### Force training for min or max epochs
 It can be useful to force training for a minimum number of epochs or limit to a max number
diff --git a/docs/index.md b/docs/index.md
index bcd7ca2159..631d940d59 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -66,7 +66,6 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.
 ###### Training loop
 
 - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients)
-- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate)
 - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs)
 - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop)
 - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping)
diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py
index 608e534e0f..ca8f604b23 100644
--- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py
+++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py
@@ -163,7 +163,8 @@ class LightningTemplateModel(LightningModule):
         :return: list of optimizers
         """
         optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-        return [optimizer]
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
+        return [optimizer], [scheduler]
 
     def __dataloader(self, train):
         # init data generators
@@ -220,7 +221,6 @@ class LightningTemplateModel(LightningModule):
         # parser.set_defaults(gradient_clip=5.0)
 
         # network params
-        parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
         parser.add_argument('--in_features', default=28*28, type=int)
         parser.add_argument('--out_features', default=10, type=int)
         parser.add_argument('--hidden_dim', default=50000, type=int)  # use 500 for CPU, 50000 for GPU to see speed difference
diff --git a/pytorch_lightning/models/sample_model_template/model_template.py b/pytorch_lightning/models/sample_model_template/model_template.py
new file mode 100644
index 0000000000..0c446a11ca
--- /dev/null
+++ b/pytorch_lightning/models/sample_model_template/model_template.py
@@ -0,0 +1,204 @@
+import torch.nn as nn
+import numpy as np
+from pytorch_lightning import LightningModule
+from test_tube import HyperOptArgumentParser
+from torchvision.datasets import MNIST
+import torchvision.transforms as transforms
+import torch
+import torch.nn.functional as F
+
+
+class ExampleModel1(LightningModule):
+    """
+    Sample model to show how to define a template
+    """
+
+    def __init__(self, hparams):
+        # init superclass
+        super(ExampleModel1, self).__init__(hparams)
+
+        self.batch_size = hparams.batch_size
+
+        # build model
+        self.__build_model()
+
+    # ---------------------
+    # MODEL SETUP
+    # ---------------------
+    def __build_model(self):
+        """
+        Layout model
+        :return:
+        """
+        self.c_d1 = nn.Linear(in_features=self.hparams.in_features, out_features=self.hparams.hidden_dim)
+        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
+        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
+
+        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, out_features=self.hparams.out_features)
+
+    # ---------------------
+    # TRAINING
+    # ---------------------
+    def forward(self, x):
+        x = self.c_d1(x)
+        x = F.tanh(x)
+        x = self.c_d1_bn(x)
+        x = self.c_d1_drop(x)
+
+        x = self.c_d2(x)
+        logits = F.log_softmax(x, dim=1)
+
+        return logits
+
+    def loss(self, labels, logits):
+        nll = F.nll_loss(logits, labels)
+        return nll
+
+    def training_step(self, data_batch):
+        """
+        Called inside the training loop
+        :param data_batch:
+        :return:
+        """
+        # forward pass
+        x, y = data_batch
+        x = x.view(x.size(0), -1)
+        y_hat = self.forward(x)
+
+        # calculate loss
+        loss_val = self.loss(y, y_hat)
+
+        tqdm_dic = {'jefe': 1}
+        return loss_val, tqdm_dic
+
+    def validation_step(self, data_batch):
+        """
+        Called inside the validation loop
+        :param data_batch:
+        :return:
+        """
+        x, y = data_batch
+        x = x.view(x.size(0), -1)
+        y_hat = self.forward(x)
+
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+
+        output = {'y_hat': y_hat, 'val_loss': loss_val.item(), 'val_acc': val_acc}
+        return output
+
+    def validation_end(self, outputs):
+        """
+        Called at the end of validation to aggregate outputs
+        :param outputs: list of individual outputs of each validation step
+        :return:
+        """
+        val_loss_mean = 0
+        accs = []
+        for output in outputs:
+            val_loss_mean += output['val_loss']
+            accs.append(output['val_acc'])
+
+        val_loss_mean /= len(outputs)
+        tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': np.mean(accs)}
+        return tqdm_dic
+
+    def update_tng_log_metrics(self, logs):
+        return logs
+
+    # ---------------------
+    # MODEL SAVING
+    # ---------------------
+    def get_save_dict(self):
+        checkpoint = {
+            'state_dict': self.state_dict(),
+        }
+
+        return checkpoint
+
+    def load_model_specific(self, checkpoint):
+        self.load_state_dict(checkpoint['state_dict'])
+        pass
+
+    # ---------------------
+    # TRAINING SETUP
+    # ---------------------
+    def configure_optimizers(self):
+        """
+        return whatever optimizers and (optionally) schedulers we want here
+        :return: list of optimizers
+        """
+        optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
+        self.optimizers = [optimizer]
+        self.schedulers = []
+        return self.optimizers, self.schedulers
+
+    def __dataloader(self, train):
+        # init data generators
+        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
+
+        dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)
+
+        loader = torch.utils.data.DataLoader(
+            dataset=dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=True
+        )
+
+        return loader
+
+    @data_loader
+    def tng_dataloader(self):
+        if self._tng_dataloader is None:
+            try:
+                self._tng_dataloader = self.__dataloader(train=True)
+            except Exception as e:
+                print(e)
+                raise e
+        return self._tng_dataloader
+
+    @property
+    def val_dataloader(self):
+        if self._val_dataloader is None:
+            try:
+                self._val_dataloader = self.__dataloader(train=False)
+            except Exception as e:
+                print(e)
+                raise e
+        return self._val_dataloader
+
+    @property
+    def test_dataloader(self):
+        if self._test_dataloader is None:
+            try:
+                self._test_dataloader = self.__dataloader(train=False)
+            except Exception as e:
+                print(e)
+                raise e
+        return self._test_dataloader
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
+
+        # param overwrites
+        # parser.set_defaults(gradient_clip=5.0)
+
+        # network params
+        parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
+        parser.add_argument('--in_features', default=28*28)
+        parser.add_argument('--hidden_dim', default=500)
+        parser.add_argument('--out_features', default=10)
+
+        # data
+        parser.add_argument('--data_root', default='/Users/williamfalcon/Developer/personal/research_lib/research_proj/datasets/mnist', type=str)
+
+        # training params (opt)
+        parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005],
+                        tunable=False)
+        parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
+        parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
+        return parser
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 4b741c9f12..ac67d83dad 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -10,7 +10,6 @@ import re
 
 import torch
 from torch.utils.data.distributed import DistributedSampler
-from torch.optim.lr_scheduler import MultiStepLR
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import numpy as np
@@ -71,7 +70,6 @@ class Trainer(TrainerIO):
                  train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0,
                  val_check_interval=0.95,
                  log_save_interval=100, add_log_row_interval=10,
-                 lr_scheduler_milestones=None,
                  distributed_backend='dp',
                  use_amp=False,
                  print_nan_grads=False,
@@ -104,7 +102,6 @@ class Trainer(TrainerIO):
         :param val_check_interval:
         :param log_save_interval:
         :param add_log_row_interval:
-        :param lr_scheduler_milestones:
         :param distributed_backend: 'np' to use DistributedParallel, 'ddp' to use DistributedDataParallel
         :param use_amp:
         :param print_nan_grads:
@@ -141,7 +138,6 @@ class Trainer(TrainerIO):
         self.early_stop_callback = early_stop_callback
         self.min_nb_epochs = min_nb_epochs
         self.nb_sanity_val_steps = nb_sanity_val_steps
-        self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
         self.lr_schedulers = []
         self.amp_level = amp_level
         self.print_nan_grads = print_nan_grads
@@ -443,7 +439,7 @@ class Trainer(TrainerIO):
 
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers = model.configure_optimizers()
+        self.optimizers, self.lr_schedulers = model.configure_optimizers()
 
         self.__run_pretrain_routine(model)
@@ -455,7 +451,7 @@ class Trainer(TrainerIO):
 
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers = model.configure_optimizers()
+        self.optimizers, self.lr_schedulers = model.configure_optimizers()
 
         model.cuda(self.data_parallel_device_ids[0])
@@ -509,7 +505,7 @@ class Trainer(TrainerIO):
 
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers = model.configure_optimizers()
+        self.optimizers, self.lr_schedulers = model.configure_optimizers()
 
         # MODEL
         # copy model to each gpu
@@ -589,12 +585,6 @@ class Trainer(TrainerIO):
         # init training constants
         self.__layout_bookeeping()
 
-        # add lr schedulers
-        if self.lr_scheduler_milestones is not None:
-            for optimizer in self.optimizers:
-                scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones)
-                self.lr_schedulers.append(scheduler)
-
         # print model summary
         if self.proc_rank == 0 and self.print_weights_summary:
             ref_model.summarize()
@@ -628,8 +618,9 @@ class Trainer(TrainerIO):
         # run all epochs
         for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
             # update the lr scheduler
-            for lr_scheduler in self.lr_schedulers:
-                lr_scheduler.step()
+            if self.lr_schedulers is not None:
+                for lr_scheduler in self.lr_schedulers:
+                    lr_scheduler.step()
 
             model = self.__get_model()
             model.current_epoch = epoch_nb
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index b49dd3af90..c860685eea 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -58,7 +58,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
     def configure_optimizers(self):
         """
-        Return array of optimizers
+        Return a list of optimizers and a list of schedulers (could be empty)
         :return:
         """
         raise NotImplementedError
diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py
index 39fcbb4c3d..a143e51e7e 100644
--- a/pytorch_lightning/testing_models/lm_test_module.py
+++ b/pytorch_lightning/testing_models/lm_test_module.py
@@ -179,8 +179,9 @@ class LightningTestModel(LightningModule):
         return whatever optimizers we want here
         :return: list of optimizers
         """
+        # try no scheduler for this model (testing purposes)
         optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-        return [optimizer]
+        return [optimizer], []
 
     def __dataloader(self, train):
         # init data generators
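
With the `lr_scheduler_milestones` Trainer argument removed above, learning rate annealing now lives in the model: `configure_optimizers` returns a second list of schedulers, and the trainer steps each of them once per epoch. Below is a minimal sketch of how the removed milestone behaviour (cut the LR by 10 at epochs 100, 200, and 300) can be reproduced under the new contract; the module class and hyperparameter values are illustrative only and are not part of the patch.

``` {.python}
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR


class ExampleModule(nn.Module):
    """Stand-in for a LightningModule subclass; only configure_optimizers matters here."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.01)
        # same effect as the removed Trainer(lr_scheduler_milestones='100, 200, 300'):
        # multiply the LR by 0.1 at epochs 100, 200, and 300
        scheduler = MultiStepLR(optimizer, milestones=[100, 200, 300], gamma=0.1)
        return [optimizer], [scheduler]


# the trainer now unpacks both lists and steps each scheduler once per epoch
model = ExampleModule()
optimizers, lr_schedulers = model.configure_optimizers()
for epoch in range(3):
    # ... run the training batches here, calling optimizer.step() per batch ...
    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
```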