From 3f78c4ca7a89fed8f79d9015e9510f0a0dcc46b5 Mon Sep 17 00:00:00 2001 From: Eric Wiener Date: Tue, 10 May 2022 03:57:38 -0700 Subject: [PATCH] Track CPU stats with DeviceStatsMonitor (#11795) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ananthsub Co-authored-by: Jirka Borovec Co-authored-by: Rohit Gupta Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kaushik B Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 3 + dockers/tpu-tests/tpu_test_cases.jsonnet | 2 +- docs/source/tuning/profiler_basic.rst | 3 + pytorch_lightning/accelerators/cpu.py | 26 ++++++- .../callbacks/device_stats_monitor.py | 77 +++++++++++++------ pytorch_lightning/utilities/imports.py | 3 +- requirements/test.txt | 1 + tests/callbacks/test_device_stats_monitor.py | 68 ++++++++++++++-- tests/helpers/runif.py | 9 ++- 9 files changed, 158 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bd36750c3..c803239d5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -590,6 +590,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with resuming from a checkpoint trained with QAT ([#11346](https://github.com/PyTorchLightning/pytorch-lightning/pull/11346)) +- Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795)) + + ## [1.5.10] - 2022-02-08 ### Fixed diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index ea7d5b6e09..c474ea6641 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -32,12 +32,12 @@ local tputests = base.BaseTest { pip install -e . echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" + # TODO (@kaushikb11): Add device stats tests here coverage run --source=pytorch_lightning -m pytest -v --capture=no \ tests/strategies/test_tpu_spawn.py \ tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ tests/accelerators/test_tpu.py \ - tests/callbacks/test_device_stats_monitor.py \ tests/models/test_tpu.py test_exit_code=$? echo "\n||| END PYTEST LOGS |||\n" diff --git a/docs/source/tuning/profiler_basic.rst b/docs/source/tuning/profiler_basic.rst index 899e657904..c3ddc114dc 100644 --- a/docs/source/tuning/profiler_basic.rst +++ b/docs/source/tuning/profiler_basic.rst @@ -119,3 +119,6 @@ This can be measured with the :class:`~pytorch_lightning.callbacks.device_stats_ from pytorch_lightning.callbacks import DeviceStatsMonitor trainer = Trainer(callbacks=[DeviceStatsMonitor()]) + +CPU metrics will be tracked by default on the CPU accelerator. To enable it for other accelerators set ``DeviceStatsMonitor(cpu_stats=True)``. To disable logging +CPU metrics, you can specify ``DeviceStatsMonitor(cpu_stats=False)``. diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 69bef65644..fea8ee70d1 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -18,6 +18,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE from pytorch_lightning.utilities.types import _DEVICE @@ -35,8 +36,8 @@ class CPUAccelerator(Accelerator): raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.") def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: - """CPU device stats aren't supported yet.""" - return {} + """Get CPU stats from ``psutil`` package.""" + return get_cpu_stats() @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> int: @@ -67,3 +68,24 @@ class CPUAccelerator(Accelerator): cls, description=f"{cls.__class__.__name__}", ) + + +# CPU device metrics +_CPU_VM_PERCENT = "cpu_vm_percent" +_CPU_PERCENT = "cpu_percent" +_CPU_SWAP_PERCENT = "cpu_swap_percent" + + +def get_cpu_stats() -> Dict[str, float]: + if not _PSUTIL_AVAILABLE: + raise ModuleNotFoundError( + "Fetching CPU device stats requires `psutil` to be installed." + " Install it by running `pip install -U psutil`." + ) + import psutil + + return { + _CPU_VM_PERCENT: psutil.virtual_memory().percent, + _CPU_PERCENT: psutil.cpu_percent(), + _CPU_SWAP_PERCENT: psutil.swap_memory().percent, + } diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index 42ea7e5d1a..00fd79d0f7 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -23,8 +23,9 @@ from typing import Any, Dict, Optional import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.utilities.warnings import rank_zero_deprecation class DeviceStatsMonitor(Callback): @@ -32,6 +33,13 @@ class DeviceStatsMonitor(Callback): Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``. + Args: + cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU. + It will raise a warning if ``psutil`` is not installed till v1.9.0. + If ``True``, it will log CPU stats regardless of the accelerator, and it will + raise an exception if ``psutil`` is not installed. + If ``False``, it will not log CPU stats regardless of the accelerator. + Raises: MisconfigurationException: If ``Trainer`` has no logger. @@ -43,45 +51,68 @@ class DeviceStatsMonitor(Callback): >>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP """ - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - if not trainer.loggers: - raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.") + def __init__(self, cpu_stats: Optional[bool] = None) -> None: + self._cpu_stats = cpu_stats - def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + def setup( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + stage: Optional[str] = None, ) -> None: + if stage != "fit": + return + if not trainer.loggers: raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") + # warn in setup to warn once + device = trainer.strategy.root_device + if self._cpu_stats is None and device.type == "cpu" and not _PSUTIL_AVAILABLE: + # TODO: raise an exception from v1.9 + rank_zero_warn( + "`DeviceStatsMonitor` will not log CPU stats as `psutil` is not installed." + " To install `psutil`, run `pip install psutil`." + " It will raise an exception if `psutil` is not installed post v1.9.0." + ) + self._cpu_stats = False + + def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None: if not trainer._logger_connector.should_update_logs: return device = trainer.strategy.root_device + if self._cpu_stats is False and device.type == "cpu": + # cpu stats are disabled + return + device_stats = trainer.accelerator.get_device_stats(device) + + if self._cpu_stats and device.type != "cpu": + # Don't query CPU stats twice if CPU is accelerator + from pytorch_lightning.accelerators.cpu import get_cpu_stats + + device_stats.update(get_cpu_stats()) + for logger in trainer.loggers: separator = logger.group_separator - prefixed_device_stats = _prefix_metric_keys( - device_stats, f"{self.__class__.__qualname__}.on_train_batch_start", separator - ) + prefixed_device_stats = _prefix_metric_keys(device_stats, f"{self.__class__.__qualname__}.{key}", separator) logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + unused: Optional[int] = 0, + ) -> None: + self._get_and_log_device_stats(trainer, "on_train_batch_start") + def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - if not trainer.loggers: - raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") - - if not trainer._logger_connector.should_update_logs: - return - - device = trainer.strategy.root_device - device_stats = trainer.accelerator.get_device_stats(device) - for logger in trainer.loggers: - separator = logger.group_separator - prefixed_device_stats = _prefix_metric_keys( - device_stats, f"{self.__class__.__qualname__}.on_train_batch_end", separator - ) - logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) + self._get_and_log_device_stats(trainer, "on_train_batch_end") def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index f2f73a7d89..247e79f713 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -105,6 +105,7 @@ _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group") +_HABANA_FRAMEWORK_AVAILABLE = _package_available("habana_frameworks") _HIVEMIND_AVAILABLE = _package_available("hivemind") _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _package_available("hydra") @@ -115,7 +116,7 @@ _NEPTUNE_AVAILABLE = _package_available("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") _OMEGACONF_AVAILABLE = _package_available("omegaconf") _POPTORCH_AVAILABLE = _package_available("poptorch") -_HABANA_FRAMEWORK_AVAILABLE = _package_available("habana_frameworks") +_PSUTIL_AVAILABLE = _package_available("psutil") _RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2") _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) _TORCHTEXT_AVAILABLE = _package_available("torchtext") diff --git a/requirements/test.txt b/requirements/test.txt index a744f57382..e43a2a3700 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -12,4 +12,5 @@ pytest-forked cloudpickle>=1.3 scikit-learn>0.22.1 onnxruntime +psutil # for `DeviceStatsMonitor` pandas # needed in benchmarks diff --git a/tests/callbacks/test_device_stats_monitor.py b/tests/callbacks/test_device_stats_monitor.py index c3108a8cef..35fa91698b 100644 --- a/tests/callbacks/test_device_stats_monitor.py +++ b/tests/callbacks/test_device_stats_monitor.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Optional +from unittest import mock +from unittest.mock import Mock import pytest +import torch from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats from pytorch_lightning.callbacks import DeviceStatsMonitor from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys from pytorch_lightning.loggers import CSVLogger @@ -34,9 +38,13 @@ def test_device_stats_gpu_from_torch(tmpdir): class DebugLogger(CSVLogger): @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"] + fields = [ + "allocated_bytes.all.freed", + "inactive_split.all.peak", + "reserved_bytes.large_pool.peak", + ] for f in fields: - assert any(f in h for h in metrics.keys()) + assert any(f in h for h in metrics) trainer = Trainer( default_root_dir=tmpdir, @@ -54,6 +62,41 @@ def test_device_stats_gpu_from_torch(tmpdir): trainer.fit(model) +@RunIf(psutil=True) +@pytest.mark.parametrize("cpu_stats", (None, True, False)) +@mock.patch("pytorch_lightning.accelerators.cpu.get_cpu_stats", side_effect=get_cpu_stats) +def test_device_stats_cpu(cpu_stats_mock, tmpdir, cpu_stats): + """Test CPU stats are logged when no accelerator is used.""" + model = BoringModel() + CPU_METRIC_KEYS = (_CPU_VM_PERCENT, _CPU_SWAP_PERCENT, _CPU_PERCENT) + + class DebugLogger(CSVLogger): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + enabled = cpu_stats is not False + for f in CPU_METRIC_KEYS: + has_cpu_metrics = any(f in h for h in metrics) + assert has_cpu_metrics if enabled else not has_cpu_metrics + + device_stats = DeviceStatsMonitor(cpu_stats=cpu_stats) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + log_every_n_steps=1, + callbacks=device_stats, + logger=DebugLogger(tmpdir), + enable_checkpointing=False, + enable_progress_bar=False, + accelerator="cpu", + ) + trainer.fit(model) + + expected = 4 if cpu_stats is not False else 0 # (batch_start + batch_end) * train_batches + assert cpu_stats_mock.call_count == expected + + +@pytest.mark.skipif(True, reason="TODO (@kaushikb11): fix this test, timeout") @RunIf(tpu=True) def test_device_stats_monitor_tpu(tmpdir): """Test TPU stats are logged using a logger.""" @@ -66,14 +109,14 @@ def test_device_stats_monitor_tpu(tmpdir): def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: fields = ["avg. free memory (MB)", "avg. peak memory (MB)"] for f in fields: - assert any(f in h for h in metrics.keys()) + assert any(f in h for h in metrics) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_train_batches=1, + limit_train_batches=2, accelerator="tpu", - devices=8, + devices=1, log_every_n_steps=1, callbacks=[device_stats], logger=DebugLogger(tmpdir), @@ -99,7 +142,7 @@ def test_device_stats_monitor_no_logger(tmpdir): enable_progress_bar=False, ) - with pytest.raises(MisconfigurationException, match="Trainer that has no logger."): + with pytest.raises(MisconfigurationException, match="Cannot use `DeviceStatsMonitor` callback."): trainer.fit(model) @@ -110,3 +153,16 @@ def test_prefix_metric_keys(tmpdir): separator = "." converted_metrics = _prefix_metric_keys(metrics, prefix, separator) assert converted_metrics == {"foo.1": 1.0, "foo.2": 2.0, "foo.3": 3.0} + + +def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch): + """Test that warning is raised when psutil is not available.""" + import pytorch_lightning.callbacks.device_stats_monitor as imports + + monkeypatch.setattr(imports, "_PSUTIL_AVAILABLE", False) + monitor = DeviceStatsMonitor() + trainer = Trainer() + assert trainer.strategy.root_device == torch.device("cpu") + # TODO: raise an exception from v1.9 + with pytest.warns(UserWarning, match="psutil` is not installed"): + monitor.setup(trainer, Mock(), "fit") diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 99d64ebd01..c032ca61e9 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -20,7 +20,7 @@ import torch from packaging.version import Version from pkg_resources import get_distribution -from pytorch_lightning.utilities import ( +from pytorch_lightning.utilities.imports import ( _APEX_AVAILABLE, _BAGUA_AVAILABLE, _DEEPSPEED_AVAILABLE, @@ -31,6 +31,7 @@ from pytorch_lightning.utilities import ( _HPU_AVAILABLE, _IPU_AVAILABLE, _OMEGACONF_AVAILABLE, + _PSUTIL_AVAILABLE, _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, _TORCH_QUANTIZE_AVAILABLE, @@ -85,6 +86,7 @@ class RunIf: omegaconf: bool = False, slow: bool = False, bagua: bool = False, + psutil: bool = False, hivemind: bool = False, **kwargs, ): @@ -113,6 +115,7 @@ class RunIf: omegaconf: Require that omry/omegaconf is installed. slow: Mark the test as slow, our CI will run it in a separate job. bagua: Require that BaguaSys/bagua is installed. + psutil: Require that psutil is installed. hivemind: Require that Hivemind is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ @@ -234,6 +237,10 @@ class RunIf: conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin")) reasons.append("Bagua") + if psutil: + conditions.append(not _PSUTIL_AVAILABLE) + reasons.append("psutil") + if hivemind: conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin")) reasons.append("Hivemind")