diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index fb99dda22a..3ede697c76 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -288,9 +288,6 @@ class Trainer(TrainerIO): # MODEL TRAINING # ----------------------------- def fit(self, model): - # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus - self.optimizers = model.configure_optimizers() # when using gpus, first thing we do is spawn a new process between each worker # applies to single gpu, multi-gpu and multi-nodes @@ -298,6 +295,10 @@ class Trainer(TrainerIO): self.experiment = self.experiment.get_meta_copy() mp.spawn(self.dp_train, nprocs=len(self.data_parallel_device_ids), args=(model, )) else: + # CHOOSE OPTIMIZER + # filter out the weights that were done on gpu so we can load on good old cpus + self.optimizers = model.configure_optimizers() + # run through amp wrapper if self.use_amp: # An example @@ -338,10 +339,16 @@ class Trainer(TrainerIO): ip = self.__get_root_node_ip(self.proc_rank, self.nb_gpu_nodes) self.__init_tcp_connection(ip) + # CHOOSE OPTIMIZER + # filter out the weights that were done on gpu so we can load on good old cpus + self.optimizers = model.configure_optimizers() + + # MODEL # copy model to each gpu torch.cuda.set_device(gpu_nb) model.cuda(gpu_nb) + # AMP # run through amp wrapper before going to distributed DP if self.use_amp: # An example