From 9445c800b032b2f71bc290baf992ed22eb04739d Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Tue, 18 Aug 2020 19:28:35 -0400 Subject: [PATCH] set device to root gpu (#3042) --- pytorch_lightning/accelerators/gpu_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index 881a7061df..ea0057dcc1 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import AMPType @@ -32,6 +33,7 @@ class GPUBackend(object): # call setup self.trainer.call_setup_hook(model) + torch.cuda.set_device(self.trainer.root_gpu) model.cuda(self.trainer.root_gpu) # CHOOSE OPTIMIZER