From b9f581ab874f30ff2e70cf186d054212790ec687 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 3 Jul 2019 15:17:02 -0400
Subject: [PATCH] added single node distdataparallel

---
 pytorch_lightning/models/trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 004cbd4788..916c0fd139 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -282,11 +282,12 @@ class Trainer(TrainerIO):
         # when GPU is called, spawn off a single worker for each gpu
         if self.on_gpu:
             rank = 0
-            mp.spawn(self.__dp_train, nprocs=len(self.data_parallel_device_ids), args=(rank, model ))
+            self.model = model
+            mp.spawn(self.__dp_train, nprocs=len(self.data_parallel_device_ids), args=(rank,))
         else:
             self.__run_pretrain_routine(model)
 
-    def __dp_train(self, gpu_nb, proc_rank, model):
+    def __dp_train(self, gpu_nb, proc_rank):
         """
         Entry point into a DP thread
         :param gpu_nb:
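
The new call site leans on two details of torch.multiprocessing.spawn: args must be a
tuple (hence the trailing comma in (rank,)), and spawn prepends each worker's process
index as the first positional argument, which is how __dp_train receives gpu_nb. Below
is a minimal standalone sketch of that pattern; the Trainer class, its attribute names,
and _dp_train are illustrative stand-ins, not Lightning's actual trainer.py code -- only
the spawn call shape and the move of the model onto self come from the diff above.

import torch.multiprocessing as mp


class Trainer:
    """Hypothetical stand-in for the Lightning Trainer, for illustration only."""

    def __init__(self, data_parallel_device_ids):
        self.data_parallel_device_ids = data_parallel_device_ids
        self.model = None

    def fit(self, model):
        # As in the patch: stash the model on self instead of passing it
        # through spawn's args, so each worker reaches it via the pickled
        # bound method rather than an extra positional argument.
        self.model = model
        rank = 0
        # args must be a tuple -- note the trailing comma. spawn prepends
        # the process index, so each worker runs _dp_train(gpu_nb, rank).
        mp.spawn(self._dp_train,
                 nprocs=len(self.data_parallel_device_ids),
                 args=(rank,))

    def _dp_train(self, gpu_nb, proc_rank):
        # One process per device id; real code would pin the device,
        # init the process group, and wrap self.model here.
        print('worker for gpu %d, rank %d' % (gpu_nb, proc_rank))


if __name__ == '__main__':
    Trainer(data_parallel_device_ids=[0, 1]).fit(model=None)

One consequence of moving the model onto self: spawn pickles the bound method, and with
it the whole Trainer instance, into each worker process, so every attribute the Trainer
carries at spawn time must itself be picklable.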