From bafdeca42f746aac59b4f0c1103264d7bff556db Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Tue, 21 Apr 2020 14:29:15 -0400 Subject: [PATCH] Replace GPU device idx with current process index (#1541) --- pytorch_lightning/trainer/distrib_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index d67ae26229..dadff04a11 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -327,7 +327,7 @@ class TrainerDDPMixin(ABC): # MODEL # copy model to each gpu if self.on_gpu: - self.root_gpu = self.data_parallel_device_ids[process_idx] + self.root_gpu = process_idx torch.cuda.set_device(self.root_gpu) model.cuda(self.root_gpu)