From 05ed9a201c24e08c2b4d3df4735296758ddcd6a5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 Jan 2022 05:26:17 -0800 Subject: [PATCH] Group metrics generated by `DeviceStatsMonitor` for better visualization (#11254) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec --- CHANGELOG.md | 7 +++++ .../callbacks/device_stats_monitor.py | 28 +++++++++++++++---- tests/callbacks/test_device_stats_monitor.py | 10 +++++++ tests/deprecated_api/test_remove_1-8.py | 7 +++++ 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index daa13bdea0..486ef208c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Renamed `training_type_plugin` file to `strategy` ([#11239](https://github.com/PyTorchLightning/pytorch-lightning/pull/11239)) +- Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) @@ -221,6 +224,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) + +- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) + + ### Removed - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507)) diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index 016d2015a8..e43783ac75 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -24,6 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.warnings import rank_zero_deprecation class DeviceStatsMonitor(Callback): @@ -54,12 +55,15 @@ class DeviceStatsMonitor(Callback): batch_idx: int, unused: Optional[int] = 0, ) -> None: + if not trainer.logger: + raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") + if not trainer.logger_connector.should_update_logs: return device_stats = trainer.accelerator.get_device_stats(pl_module.device) - prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start") - assert trainer.logger is not None + separator = trainer.logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step) def on_train_batch_end( @@ -71,14 +75,26 @@ class DeviceStatsMonitor(Callback): batch_idx: int, unused: Optional[int] = 0, ) -> None: + if not trainer.logger: + raise 
MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") + if not trainer.logger_connector.should_update_logs: return device_stats = trainer.accelerator.get_device_stats(pl_module.device) - prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end") - assert trainer.logger is not None + separator = trainer.logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step) -def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]: - return {prefix + "." + k: v for k, v in metrics_dict.items()} +def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: + return {prefix + separator + k: v for k, v in metrics_dict.items()} + + +def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]: + rank_zero_deprecation( + "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`" + " is deprecated in v1.6 and will be removed in v1.8." 
+ ) + sep = "" + return _prefix_metric_keys(metrics_dict, prefix, sep) diff --git a/tests/callbacks/test_device_stats_monitor.py b/tests/callbacks/test_device_stats_monitor.py index 5488cd44de..e5b4516182 100644 --- a/tests/callbacks/test_device_stats_monitor.py +++ b/tests/callbacks/test_device_stats_monitor.py @@ -17,6 +17,7 @@ import pytest from pytorch_lightning import Trainer from pytorch_lightning.callbacks import DeviceStatsMonitor +from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -128,3 +129,12 @@ def test_device_stats_monitor_no_logger(tmpdir): with pytest.raises(MisconfigurationException, match="Trainer that has no logger."): trainer.fit(model) + + +def test_prefix_metric_keys(tmpdir): + """Test that metric key names are converted correctly.""" + metrics = {"1": 1.0, "2": 2.0, "3": 3.0} + prefix = "foo" + separator = "." + converted_metrics = _prefix_metric_keys(metrics, prefix, separator) + assert converted_metrics == {"foo.1": 1.0, "foo.2": 2.0, "foo.3": 3.0} diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 750b6ae2b2..8eaebc0b51 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -271,6 +271,13 @@ def test_v1_8_0_deprecated_training_type_plugin_property(): trainer.training_type_plugin +def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys(): + from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys + + with pytest.deprecated_call(match="in v1.6 and will be removed in v1.8"): + prefix_metric_keys({"foo": 1.0}, "bar") + + @pytest.mark.parametrize( "cls", [