allow optimizer fx to return 1 or 2 lists

williamFalcon 2019-07-28 06:33:58 -07:00
parent b9e0d841dc
commit 638d79a5a6
1 changed file with 12 additions and 6 deletions

@@ -438,8 +438,10 @@ class Trainer(TrainerIO):
         raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers

         self.__run_pretrain_routine(model)
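
With this change, configure_optimizers() may return either a single list of optimizers or a pair of (optimizers, lr_schedulers). Below is a minimal sketch of the two accepted return shapes and the trainer-side unpacking; it is not code from the repository, and the network and function names are illustrative:

    import torch
    from torch import nn

    net = nn.Linear(10, 2)  # illustrative network

    def configure_optimizers_plain():
        # shape 1: a single list of optimizers
        return [torch.optim.Adam(net.parameters(), lr=1e-3)]

    def configure_optimizers_with_schedulers():
        # shape 2: an (optimizers, lr_schedulers) pair
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return [optimizer], [scheduler]

    # trainer-side handling, mirroring the diff above
    optimizers = configure_optimizers_with_schedulers()
    if len(optimizers) == 2:
        optimizers, lr_schedulers = optimizers
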
@@ -450,8 +452,10 @@ class Trainer(TrainerIO):
     def __dp_train(self, model):

         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers

         model.cuda(self.data_parallel_device_ids[0])
@@ -504,8 +508,10 @@ class Trainer(TrainerIO):
         self.__init_tcp_connection()

         # CHOOSE OPTIMIZER
-        # filter out the weights that were done on gpu so we can load on good old cpus
-        self.optimizers, self.lr_schedulers = model.configure_optimizers()
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers

         # MODEL
         # copy model to each gpu
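
One caveat worth noting (an observation about the unpacking logic above, not documented behavior): len(...) == 2 cannot distinguish an (optimizers, lr_schedulers) pair from a plain list of exactly two optimizers, so a hypothetical scheduler-free setup with two optimizers would be unpacked incorrectly:

    import torch
    from torch import nn

    # hypothetical two-optimizer setup (e.g. GAN-style), no schedulers
    generator = nn.Linear(10, 10)
    discriminator = nn.Linear(10, 1)
    opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

    returned = [opt_g, opt_d]             # len(returned) == 2
    optimizers, lr_schedulers = returned  # opt_d is mistaken for the scheduler list
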