Make `_get_nvidia_gpu_stats` public (#10406)

This commit is contained in:
Rohit Gupta 2021-11-19 22:22:24 +05:30 committed by GitHub
parent 17a8290ca7
commit ff8ac6e2e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 7 additions and 9 deletions

View File

@ -73,7 +73,7 @@ class GPUAccelerator(Accelerator):
"""
if _TORCH_GREATER_EQUAL_1_8:
return torch.cuda.memory_stats(device)
return _get_nvidia_gpu_stats(device)
return get_nvidia_gpu_stats(device)
def teardown(self) -> None:
super().teardown()
@ -85,7 +85,7 @@ class GPUAccelerator(Accelerator):
return torch.cuda.device_count()
def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
Args:
@ -98,6 +98,10 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
FileNotFoundError:
If nvidia-smi installation not found
"""
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path is None:
raise FileNotFoundError("nvidia-smi: command not found")
gpu_stat_metrics = [
("utilization.gpu", "%"),
("memory.used", "MB"),
@ -111,9 +115,6 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
gpu_query = ",".join(gpu_stat_keys)
gpu_id = _get_gpu_id(device.index)
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path is None:
raise FileNotFoundError("nvidia-smi: command not found")
result = subprocess.run(
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
encoding="utf-8",
@ -130,10 +131,7 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
s = result.stdout.strip()
stats = [_to_float(x) for x in s.split(", ")]
gpu_stats = {}
for i, (x, unit) in enumerate(gpu_stat_metrics):
gpu_stats[f"{x} ({unit})"] = stats[i]
gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)}
return gpu_stats