diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index d67ae26229..dadff04a11 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -327,7 +327,7 @@ class TrainerDDPMixin(ABC):
         # MODEL
         # copy model to each gpu
         if self.on_gpu:
-            self.root_gpu = self.data_parallel_device_ids[process_idx]
+            self.root_gpu = process_idx
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
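
Note: the hunk alone does not show why process_idx can replace self.data_parallel_device_ids[process_idx]. A plausible reading, assuming the launcher already restricts CUDA_VISIBLE_DEVICES to the selected GPUs before the DDP workers start (that setup is not part of this hunk), is that each worker then sees the chosen devices re-numbered from 0, so the worker's own index is the correct local device ordinal. The sketch below only illustrates that assumption; the variable names mirror the diff, everything else is hypothetical.

    import os

    # Suppose the user requested physical GPUs 2 and 3 (illustrative values).
    data_parallel_device_ids = [2, 3]

    # Assumption: the parent process limits visibility to those GPUs before
    # spawning workers, so inside each worker the devices are re-indexed 0..N-1.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in data_parallel_device_ids)

    for process_idx in range(len(data_parallel_device_ids)):
        root_gpu = process_idx  # new behaviour: local index 0 or 1
        # old behaviour: data_parallel_device_ids[process_idx] would yield 2 or 3,
        # which is out of range when only two devices are visible to the worker.
        print(f"worker {process_idx} -> cuda:{root_gpu}")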