Make `_get_nvidia_gpu_stats` public (#10406)
This commit is contained in:
parent
17a8290ca7
commit
ff8ac6e2e1
|
@ -73,7 +73,7 @@ class GPUAccelerator(Accelerator):
|
|||
"""
|
||||
if _TORCH_GREATER_EQUAL_1_8:
|
||||
return torch.cuda.memory_stats(device)
|
||||
return _get_nvidia_gpu_stats(device)
|
||||
return get_nvidia_gpu_stats(device)
|
||||
|
||||
def teardown(self) -> None:
|
||||
super().teardown()
|
||||
|
@ -85,7 +85,7 @@ class GPUAccelerator(Accelerator):
|
|||
return torch.cuda.device_count()
|
||||
|
||||
|
||||
def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
|
||||
def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
|
||||
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
|
||||
|
||||
Args:
|
||||
|
@ -98,6 +98,10 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
|
|||
FileNotFoundError:
|
||||
If the nvidia-smi installation is not found
|
||||
"""
|
||||
nvidia_smi_path = shutil.which("nvidia-smi")
|
||||
if nvidia_smi_path is None:
|
||||
raise FileNotFoundError("nvidia-smi: command not found")
|
||||
|
||||
gpu_stat_metrics = [
|
||||
("utilization.gpu", "%"),
|
||||
("memory.used", "MB"),
|
||||
|
@ -111,9 +115,6 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
|
|||
gpu_query = ",".join(gpu_stat_keys)
|
||||
|
||||
gpu_id = _get_gpu_id(device.index)
|
||||
nvidia_smi_path = shutil.which("nvidia-smi")
|
||||
if nvidia_smi_path is None:
|
||||
raise FileNotFoundError("nvidia-smi: command not found")
|
||||
result = subprocess.run(
|
||||
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
|
||||
encoding="utf-8",
|
||||
|
@ -130,10 +131,7 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
|
|||
|
||||
s = result.stdout.strip()
|
||||
stats = [_to_float(x) for x in s.split(", ")]
|
||||
|
||||
gpu_stats = {}
|
||||
for i, (x, unit) in enumerate(gpu_stat_metrics):
|
||||
gpu_stats[f"{x} ({unit})"] = stats[i]
|
||||
gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)}
|
||||
return gpu_stats
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue