Group metrics generated by `DeviceStatsMonitor` for better visualization (#11254)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
17cb3c70f7
commit
05ed9a201c
|
@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Renamed `training_type_plugin` file to `strategy` ([#11239](https://github.com/PyTorchLightning/pytorch-lightning/pull/11239))
|
||||
|
||||
|
||||
- Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
|
||||
|
@ -221,6 +224,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))
|
||||
|
||||
|
||||
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
|
||||
|
|
|
@ -24,6 +24,7 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
|
||||
|
||||
|
||||
class DeviceStatsMonitor(Callback):
|
||||
|
@ -54,12 +55,15 @@ class DeviceStatsMonitor(Callback):
|
|||
batch_idx: int,
|
||||
unused: Optional[int] = 0,
|
||||
) -> None:
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
|
||||
|
||||
if not trainer.logger_connector.should_update_logs:
|
||||
return
|
||||
|
||||
device_stats = trainer.accelerator.get_device_stats(pl_module.device)
|
||||
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
|
||||
assert trainer.logger is not None
|
||||
separator = trainer.logger.group_separator
|
||||
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
|
||||
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
|
||||
|
||||
def on_train_batch_end(
|
||||
|
@ -71,14 +75,26 @@ class DeviceStatsMonitor(Callback):
|
|||
batch_idx: int,
|
||||
unused: Optional[int] = 0,
|
||||
) -> None:
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
|
||||
|
||||
if not trainer.logger_connector.should_update_logs:
|
||||
return
|
||||
|
||||
device_stats = trainer.accelerator.get_device_stats(pl_module.device)
|
||||
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
|
||||
assert trainer.logger is not None
|
||||
separator = trainer.logger.group_separator
|
||||
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
|
||||
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
|
||||
|
||||
|
||||
def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
|
||||
return {prefix + "." + k: v for k, v in metrics_dict.items()}
|
||||
def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
|
||||
return {prefix + separator + k: v for k, v in metrics_dict.items()}
|
||||
|
||||
|
||||
def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
    """Prefix every key of ``metrics_dict`` with ``prefix``.

    .. deprecated:: v1.6
        This function is deprecated in v1.6 and will be removed in v1.8.

    Args:
        metrics_dict: mapping of metric name to value.
        prefix: string prepended to every key.

    Returns:
        A new dict with prefixed keys; values are unchanged.
    """
    rank_zero_deprecation(
        # Fixed: the message previously named `prefix_metrics`, a symbol that does not
        # exist — it must name this function so users know what to migrate away from.
        "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys`"
        " is deprecated in v1.6 and will be removed in v1.8."
    )
    # NOTE(review): the pre-1.6 implementation joined with "." (`prefix + "." + key`);
    # an empty separator changes output for legacy callers — confirm this is intended.
    sep = ""
    return _prefix_metric_keys(metrics_dict, prefix, sep)
|
||||
|
|
|
@ -17,6 +17,7 @@ import pytest
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import DeviceStatsMonitor
|
||||
from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys
|
||||
from pytorch_lightning.loggers import CSVLogger
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -128,3 +129,12 @@ def test_device_stats_monitor_no_logger(tmpdir):
|
|||
|
||||
with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_prefix_metric_keys(tmpdir):
    """Metric keys must come back as ``<prefix><separator><key>`` with values untouched."""
    source = {"1": 1.0, "2": 2.0, "3": 3.0}
    result = _prefix_metric_keys(source, "foo", ".")
    expected = {f"foo.{name}": value for name, value in source.items()}
    assert result == expected
|
||||
|
|
|
@ -271,6 +271,13 @@ def test_v1_8_0_deprecated_training_type_plugin_property():
|
|||
trainer.training_type_plugin
|
||||
|
||||
|
||||
def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys():
    """Calling ``prefix_metric_keys`` must emit the v1.6 deprecation warning."""
    from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys

    metrics = {"foo": 1.0}
    with pytest.deprecated_call(match="in v1.6 and will be removed in v1.8"):
        prefix_metric_keys(metrics, "bar")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cls",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue