From 638d79a5a63a1ef5bf66294a0d02e340836b3555 Mon Sep 17 00:00:00 2001 From: williamFalcon Date: Sun, 28 Jul 2019 06:33:58 -0700 Subject: [PATCH] allow optimizer fx to return 1 or 2 lists --- pytorch_lightning/models/trainer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ac67d83dad..5ce253c29c 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -438,8 +438,10 @@ class Trainer(TrainerIO): raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers, self.lr_schedulers = model.configure_optimizers() + # allow for lr schedulers as well + self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers self.__run_pretrain_routine(model) @@ -450,8 +452,10 @@ class Trainer(TrainerIO): def __dp_train(self, model): # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers, self.lr_schedulers = model.configure_optimizers() + # allow for lr schedulers as well + self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers model.cuda(self.data_parallel_device_ids[0]) @@ -504,8 +508,10 @@ class Trainer(TrainerIO): self.__init_tcp_connection() # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers, self.lr_schedulers = model.configure_optimizers() + # allow for lr schedulers as well + self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers # MODEL # copy model to each gpu