From 6bc616d78f13c9921f3a08f7c71229b81be8b5ca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 10 May 2021 05:26:15 +0200
Subject: [PATCH] fix display bug (#7395)

---
 pytorch_lightning/accelerators/gpu.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index d14b7cbeb9..03303edfc5 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -36,7 +36,7 @@ class GPUAccelerator(Accelerator):
         """
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
-        self.set_nvidia_flags()
+        self.set_nvidia_flags(trainer.local_rank)
         torch.cuda.set_device(self.root_device)
         return super().setup(trainer, model)

@@ -55,12 +55,12 @@ class GPUAccelerator(Accelerator):
         torch.cuda.empty_cache()

     @staticmethod
-    def set_nvidia_flags() -> None:
+    def set_nvidia_flags(local_rank: int) -> None:
         # set the correct cuda visible devices (using pci order)
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
-        _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")
+        _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

     def to_device(self, batch: Any) -> Any:
         # no need to transfer batch to device in DP mode
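
Note (not part of the patch): a minimal sketch of the display bug this change addresses, under the assumption that the process was started by a spawn-style launcher that does not export a LOCAL_RANK environment variable. The helper names below are hypothetical and exist only to contrast the two lookups; the real change simply passes trainer.local_rank into set_nvidia_flags.

    import os

    def report_rank_from_env() -> str:
        # Old behavior: read the rank from the environment. If the launcher
        # never exported LOCAL_RANK, every process falls back to the default
        # and reports rank 0 in the log line.
        return f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)}"

    def report_rank_from_caller(local_rank: int) -> str:
        # New behavior: the caller passes the rank it already tracks
        # (trainer.local_rank), which is correct regardless of how the
        # processes were launched.
        return f"LOCAL_RANK: {local_rank}"

    if __name__ == "__main__":
        # Simulate two worker processes whose environment lacks LOCAL_RANK.
        os.environ.pop("LOCAL_RANK", None)
        for rank in range(2):
            print(report_rank_from_env(), "vs", report_rank_from_caller(rank))
        # Env-based path prints "LOCAL_RANK: 0" for both workers;
        # caller-based path prints 0 and 1 as expected.

Making set_nvidia_flags take local_rank as an argument also keeps it a pure @staticmethod: the rank is explicit in the signature instead of being an implicit dependency on process environment.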