From ab069876cb19bb9de0179f74c6f83764876a73ff Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Sun, 26 Sep 2021 21:09:16 -0700 Subject: [PATCH] [1/4] Add get_device_stats to accelerator interface (#9586) --- CHANGELOG.md | 3 + pytorch_lightning/accelerators/accelerator.py | 12 +++ pytorch_lightning/accelerators/cpu.py | 8 ++ pytorch_lightning/accelerators/gpu.py | 81 +++++++++++++++++++ pytorch_lightning/accelerators/tpu.py | 20 ++++- tests/accelerators/test_gpu.py | 36 +++++++++ tests/accelerators/test_tpu.py | 16 ++++ 7 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 tests/accelerators/test_gpu.py create mode 100644 tests/accelerators/test_tpu.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4291bc0744..4538968564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) +- Added `get_device_stats` to the Accelerator Interface and added its implementation for GPU and TPU ([#9586](https://github.com/PyTorchLightning/pytorch-lightning/pull/9586)) + + - Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c6aa2f75f7..ae76fefb8d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -41,6 +41,7 @@ class Accelerator: - CPU - GPU - TPU + - IPU Each Accelerator gets two plugins upon initialization: One to handle differences from the training routine and one to handle different precisions. @@ -422,6 +423,17 @@ class Accelerator: """ return self.training_type_plugin.restore_checkpoint_after_pre_dispatch + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + """Gets stats for a given device. + + Args: + device: device for which to get stats + + Returns: + Dictionary of device stats + """ + raise NotImplementedError + def on_train_start(self) -> None: """Called when train begins.""" return self.training_type_plugin.on_train_start() diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 46e74193fb..baa922b6d7 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Union + +import torch + import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,3 +33,7 @@ class CPUAccelerator(Accelerator): raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.") return super().setup(trainer) + + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + """Returns dummy implementation for now.""" + return {} diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 7672f2edea..b33903c2d6 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -13,12 +13,16 @@ # limitations under the License. import logging import os +import shutil +import subprocess +from typing import Any, Dict, List, Union import torch import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 _log = logging.getLogger(__name__) @@ -53,6 +57,83 @@ class GPUAccelerator(Accelerator): devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + """Gets stats for the given GPU device. + + Args: + device: GPU device for which to get stats + + Returns: + A dictionary mapping the metrics to their values. + + Raises: + FileNotFoundError: + If nvidia-smi installation not found + """ + if _TORCH_GREATER_EQUAL_1_8: + return torch.cuda.memory_stats(device) + return _get_nvidia_gpu_stats(device) + def teardown(self) -> None: super().teardown() self._move_optimizer_state(torch.device("cpu")) + + +def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: + """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. + + Args: + device: GPU device for which to get stats + + Returns: + A dictionary mapping the metrics to their values. + + Raises: + FileNotFoundError: + If nvidia-smi installation not found + """ + gpu_stat_metrics = [ + ("utilization.gpu", "%"), + ("memory.used", "MB"), + ("memory.free", "MB"), + ("utilization.memory", "%"), + ("fan.speed", "%"), + ("temperature.gpu", "°C"), + ("temperature.memory", "°C"), + ] + gpu_stat_keys = [k for k, _ in gpu_stat_metrics] + gpu_query = ",".join(gpu_stat_keys) + + gpu_id = _get_gpu_id(device.index) + nvidia_smi_path = shutil.which("nvidia-smi") + if nvidia_smi_path is None: + raise FileNotFoundError("nvidia-smi: command not found") + result = subprocess.run( + [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], + encoding="utf-8", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True, + ) + + def _to_float(x: str) -> float: + try: + return float(x) + except ValueError: + return 0.0 + + s = result.stdout.strip() + stats = [_to_float(x) for x in s.split(", ")] + + gpu_stats = {} + for i, (x, unit) in enumerate(gpu_stat_metrics): + gpu_stats[f"{x} ({unit})"] = stats[i] + return gpu_stats + + +def _get_gpu_id(device_id: int) -> str: + """Get the unmasked real GPU IDs.""" + # All devices if `CUDA_VISIBLE_DEVICES` unset + default = ",".join(str(i) for i in range(torch.cuda.device_count())) + cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") + return cuda_visible_devices[device_id].strip() diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 1424bf2157..68925ab67a 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional, Union import torch from torch.optim import Optimizer @@ -61,3 +61,21 @@ class TPUAccelerator(Accelerator): for opt in self.optimizers: for p, v in opt.state.items(): opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) + + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + """Gets stats for the given TPU device. + + Args: + device: TPU device for which to get stats + + Returns: + A dictionary mapping the metrics (free memory and peak memory) to their values. + """ + memory_info = xm.get_memory_info(device) + free_memory = memory_info["kb_free"] + peak_memory = memory_info["kb_total"] - free_memory + device_stats = { + "avg. free memory (MB)": free_memory, + "avg. peak memory (MB)": peak_memory, + } + return device_stats diff --git a/tests/accelerators/test_gpu.py b/tests/accelerators/test_gpu.py new file mode 100644 index 0000000000..85ce0cd9f0 --- /dev/null +++ b/tests/accelerators/test_gpu.py @@ -0,0 +1,36 @@ +import torch + +from pytorch_lightning.accelerators import GPUAccelerator +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin +from tests.helpers.runif import RunIf + + +@RunIf(min_torch="1.8") +@RunIf(min_gpus=1) +def test_get_torch_gpu_stats(tmpdir): + """Test GPU get_device_stats with Pytorch >= 1.8.0.""" + current_device = torch.device(f"cuda:{torch.cuda.current_device()}") + GPUAccel = GPUAccelerator( + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() + ) + gpu_stats = GPUAccel.get_device_stats(current_device) + fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"] + + for f in fields: + assert any(f in h for h in gpu_stats.keys()) + + +@RunIf(max_torch="1.7") +@RunIf(min_gpus=1) +def test_get_nvidia_gpu_stats(tmpdir): + """Test GPU get_device_stats with Pytorch < 1.8.0.""" + current_device = torch.device(f"cuda:{torch.cuda.current_device()}") + GPUAccel = GPUAccelerator( + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() + ) + gpu_stats = GPUAccel.get_device_stats(current_device) + fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"] + + for f in fields: + assert any(f in h for h in gpu_stats.keys()) diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py new file mode 100644 index 0000000000..f3a2c50c0e --- /dev/null +++ b/tests/accelerators/test_tpu.py @@ -0,0 +1,16 @@ +from pytorch_lightning.accelerators import TPUAccelerator +from pytorch_lightning.plugins import SingleTPUPlugin +from pytorch_lightning.plugins.training_type import TPUSpawnPlugin +from tests.helpers.runif import RunIf + + +@RunIf(tpu=True) +def test_device_stats_tpu(tmpdir): + """Test TPU get_device_stats.""" + plugin = SingleTPUPlugin(1) + TPUAccel = TPUAccelerator(training_type_plugin=TPUSpawnPlugin(), precision_plugin=plugin) + tpu_stats = TPUAccel.get_device_stats("1") + fields = ["avg. free memory (MB)", "avg. peak memory (MB)"] + + for f in fields: + assert any(f in h for h in tpu_stats.keys())