From 16f4cc9ff08f38debd0086a0a489b3e16e208c30 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 3 Apr 2020 17:56:19 -0400
Subject: [PATCH] Shubhamagarwal92 master (#1349)

* SA: for #958: set torch cuda device when finding root

* SA: for #958: removing root gpu hack in trainer/evaluation_loop

* SA: setting torch cuda device

* comment line too long

* check if root gpu exists or available

* Incorporating suggestions on #1094

* since root gpu returns none instead of -1 for cpu

* undo changes

* fixed dp memory thing

Co-authored-by: Shubham Agarwal
---
 pytorch_lightning/trainer/distrib_parts.py | 3 +++
 pytorch_lightning/trainer/trainer.py       | 1 +
 2 files changed, 4 insertions(+)

diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index d5217c8951..7abf987d5c 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -526,6 +526,9 @@ class TrainerDPMixin(ABC):
         if isinstance(device_ids, int):
             device_ids = list(range(device_ids))
 
+        # set dp device
+        torch.cuda.set_device(self.root_gpu)
+
         model = LightningDataParallel(model, device_ids=device_ids)
 
         self.run_pretrain_routine(model)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dca3df6a06..dca3d53928 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -389,6 +389,7 @@ class Trainer(
         self.gpus = gpus
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        self.root_device = torch.device("cpu")
 
         # tpu state flags
         self.use_tpu = False
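
Note (not part of the patch): the fix above calls torch.cuda.set_device(self.root_gpu) before wrapping the model in LightningDataParallel so that the CUDA context is created on the intended root GPU rather than defaulting to cuda:0. The minimal sketch below illustrates the same pattern with plain torch.nn.DataParallel as a stand-in for LightningDataParallel; the device ids, root_gpu, and toy model are hypothetical and chosen only for the example.

# Sketch of the "set the root device before DataParallel" pattern; assumes
# at least two visible GPUs and uses nn.DataParallel, not Lightning internals.
import torch
import torch.nn as nn

device_ids = [0, 1]        # hypothetical GPUs assigned to this process
root_gpu = device_ids[0]   # the "root" GPU that gathers DataParallel outputs

if torch.cuda.is_available():
    # Make the root GPU the current CUDA device so context/allocator state
    # is created there instead of on the default device (cuda:0).
    torch.cuda.set_device(root_gpu)

    model = nn.Linear(32, 4).to(f"cuda:{root_gpu}")
    # DataParallel scatters the batch across device_ids and gathers the
    # result on output_device, mirroring the patched dp_train path.
    dp_model = nn.DataParallel(model, device_ids=device_ids, output_device=root_gpu)

    out = dp_model(torch.randn(8, 32))
    print(out.device)  # expected: the root GPU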