diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index b84e53c042..62af5f27dc 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -73,7 +73,7 @@ class GPUAccelerator(Accelerator): """ if _TORCH_GREATER_EQUAL_1_8: return torch.cuda.memory_stats(device) - return _get_nvidia_gpu_stats(device) + return get_nvidia_gpu_stats(device) def teardown(self) -> None: super().teardown() @@ -85,7 +85,7 @@ class GPUAccelerator(Accelerator): return torch.cuda.device_count() -def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: +def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: @@ -98,6 +98,10 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: FileNotFoundError: If nvidia-smi installation not found """ + nvidia_smi_path = shutil.which("nvidia-smi") + if nvidia_smi_path is None: + raise FileNotFoundError("nvidia-smi: command not found") + gpu_stat_metrics = [ ("utilization.gpu", "%"), ("memory.used", "MB"), @@ -111,9 +115,6 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: gpu_query = ",".join(gpu_stat_keys) gpu_id = _get_gpu_id(device.index) - nvidia_smi_path = shutil.which("nvidia-smi") - if nvidia_smi_path is None: - raise FileNotFoundError("nvidia-smi: command not found") result = subprocess.run( [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], encoding="utf-8", @@ -130,10 +131,7 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: s = result.stdout.strip() stats = [_to_float(x) for x in s.split(", ")] - - gpu_stats = {} - for i, (x, unit) in enumerate(gpu_stat_metrics): - gpu_stats[f"{x} ({unit})"] = stats[i] + gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)} return gpu_stats