set dp as default backend
This commit is contained in:
parent
c163caf8cb
commit
e02857fcce
|
@ -363,9 +363,6 @@ class Trainer(TrainerIO):
|
||||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||||
self.optimizers = model.configure_optimizers()
|
self.optimizers = model.configure_optimizers()
|
||||||
|
|
||||||
# attach model to DP
|
|
||||||
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
|
|
||||||
|
|
||||||
# run through amp wrapper
|
# run through amp wrapper
|
||||||
if self.use_amp:
|
if self.use_amp:
|
||||||
# An example
|
# An example
|
||||||
|
@ -374,6 +371,9 @@ class Trainer(TrainerIO):
|
||||||
)
|
)
|
||||||
self.optimizers = optimizers
|
self.optimizers = optimizers
|
||||||
|
|
||||||
|
if self.on_gpu:
|
||||||
|
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||||
|
|
||||||
self.__run_pretrain_routine(model)
|
self.__run_pretrain_routine(model)
|
||||||
|
|
||||||
def ddp_train(self, gpu_nb, model):
|
def ddp_train(self, gpu_nb, model):
|
||||||
|
|
Loading…
Reference in New Issue