diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index ac67d83dad..5ce253c29c 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -438,8 +438,10 @@ class Trainer(TrainerIO):
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
 
         self.__run_pretrain_routine(model)
 
@@ -450,8 +452,10 @@ class Trainer(TrainerIO):
     def __dp_train(self, model):
 
         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
 
         model.cuda(self.data_parallel_device_ids[0])
 
@@ -504,8 +508,10 @@ class Trainer(TrainerIO):
         self.__init_tcp_connection()
 
         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
 
         # MODEL
         # copy model to each gpu
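
Note that the len(self.optimizers) == 2 check added in this patch is ambiguous: configure_optimizers() could legitimately return exactly two optimizers and no schedulers (e.g. for a GAN), in which case the second optimizer would be mis-unpacked as the scheduler list. Below is a minimal sketch of a more defensive normalization under that assumption; the helper name init_optimizers and the isinstance checks are illustrative, not part of this patch.

import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler


def init_optimizers(output):
    """Normalize configure_optimizers() output to (optimizers, lr_schedulers)."""
    # a bare optimizer: wrap it in a list, no schedulers
    if isinstance(output, optim.Optimizer):
        return [output], []

    # a 2-tuple whose second element is a list of schedulers: unpack both
    if (isinstance(output, (list, tuple)) and len(output) == 2
            and isinstance(output[1], (list, tuple))
            and all(isinstance(s, _LRScheduler) for s in output[1])):
        optimizers, lr_schedulers = output
        return list(optimizers), list(lr_schedulers)

    # otherwise treat it as a plain list/tuple of optimizers
    return list(output), []

With a helper like this, each of the three call sites above would reduce to a single line, self.optimizers, self.lr_schedulers = init_optimizers(model.configure_optimizers()), instead of repeating the length check in every training path.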