From b3fe836656008fb6688917c8692f190a389f30f7 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 24 Apr 2021 02:25:33 -0700 Subject: [PATCH] Move metrics_to_scalars to a dedicated utilities file (#7180) * rm-trainer-logging * Update CHANGELOG.md * Update metrics.py * Update logging.py * Update metrics.py --- CHANGELOG.md | 6 +++ .../logger_connector/logger_connector.py | 3 +- pytorch_lightning/trainer/logging.py | 33 ++++++--------- pytorch_lightning/utilities/metrics.py | 40 +++++++++++++++++++ tests/deprecated_api/test_remove_1-5.py | 6 +++ 5 files changed, 67 insertions(+), 21 deletions(-) create mode 100644 pytorch_lightning/utilities/metrics.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a484602c7..c71f55f8a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `teardown` hook to `ClusterEnvironment` ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) +- Added utils for metrics to scalar conversions ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180)) + + - Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) @@ -146,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180)) + + - Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ca2f5b18cb..ebf8a13afe 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -26,6 +26,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store im from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities.metrics import metrics_to_scalars class LoggerConnector: @@ -210,7 +211,7 @@ class LoggerConnector: metrics.update(grad_norm_dic) # turn all tensors to scalars - scalar_metrics = self.trainer.metrics_to_scalars(metrics) + scalar_metrics = metrics_to_scalars(metrics) if "step" in scalar_metrics and step is None: step = scalar_metrics.pop("step") diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 78a5131538..0a59b9d8d4 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -14,28 +14,21 @@ from abc import ABC -import torch - -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities.metrics import metrics_to_scalars as new_metrics_to_scalars class TrainerLoggingMixin(ABC): + """ + TODO: Remove this class in v1.5. - def metrics_to_scalars(self, metrics): - new_metrics = {} - # TODO: this is duplicated in MetricsHolder. should be unified - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - if v.numel() != 1: - raise MisconfigurationException( - f"The metric `{k}` does not contain a single element" - f" thus it cannot be converted to float. Found `{v}`" - ) - v = v.item() + Use the utilities from ``pytorch_lightning.utilities.metrics`` instead. + """ - if isinstance(v, dict): - v = self.metrics_to_scalars(v) - - new_metrics[k] = v - - return new_metrics + def metrics_to_scalars(self, metrics: dict) -> dict: + rank_zero_deprecation( + "Internal: TrainerLoggingMixin.metrics_to_scalars is deprecated in v1.3" + " and will be removed in v1.5." + " Use `pytorch_lightning.utilities.metrics.metrics_to_scalars` instead." + ) + return new_metrics_to_scalars(metrics) diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py new file mode 100644 index 0000000000..bd57470dc2 --- /dev/null +++ b/pytorch_lightning/utilities/metrics.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions to operate on metric values. """ + +import torch + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def metrics_to_scalars(metrics: dict) -> dict: + """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ + + # TODO: this is duplicated in MetricsHolder. should be unified + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + if v.numel() != 1: + raise MisconfigurationException( + f"The metric `{k}` does not contain a single element" + f" thus it cannot be converted to float. Found `{v}`" + ) + v = v.item() + + if isinstance(v, dict): + v = metrics_to_scalars(v) + + new_metrics[k] = v + + return new_metrics diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index b1227e20d1..0a838da2fa 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -242,3 +242,9 @@ def test_v1_5_0_auto_move_data(): @auto_move_data def bar(self): pass + + +def test_v1_5_0_trainer_logging_mixin(tmpdir): + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + trainer.metrics_to_scalars({})