Allow the optimizer function (configure_optimizers) to return 1 or 2 lists
This commit is contained in:
parent
b9e0d841dc
commit
638d79a5a6
|
@ -438,8 +438,10 @@ class Trainer(TrainerIO):
|
|||
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
self.optimizers, self.lr_schedulers = model.configure_optimizers()
|
||||
# allow for lr schedulers as well
|
||||
self.optimizers = model.configure_optimizers()
|
||||
if len(self.optimizers) == 2:
|
||||
self.optimizers, self.lr_schedulers = self.optimizers
|
||||
|
||||
self.__run_pretrain_routine(model)
|
||||
|
||||
|
@ -450,8 +452,10 @@ class Trainer(TrainerIO):
|
|||
def __dp_train(self, model):
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
self.optimizers, self.lr_schedulers = model.configure_optimizers()
|
||||
# allow for lr schedulers as well
|
||||
self.optimizers = model.configure_optimizers()
|
||||
if len(self.optimizers) == 2:
|
||||
self.optimizers, self.lr_schedulers = self.optimizers
|
||||
|
||||
model.cuda(self.data_parallel_device_ids[0])
|
||||
|
||||
|
@ -504,8 +508,10 @@ class Trainer(TrainerIO):
|
|||
self.__init_tcp_connection()
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
self.optimizers, self.lr_schedulers = model.configure_optimizers()
|
||||
# allow for lr schedulers as well
|
||||
self.optimizers = model.configure_optimizers()
|
||||
if len(self.optimizers) == 2:
|
||||
self.optimizers, self.lr_schedulers = self.optimizers
|
||||
|
||||
# MODEL
|
||||
# copy model to each gpu
|
||||
|
|
Loading…
Reference in New Issue