diff --git a/CHANGELOG.md b/CHANGELOG.md index 79e7be29e0..ae07f5420e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end` +- Changed `metrics_to_scalars` to work with any collection or value ([#7888](https://github.com/PyTorchLightning/pytorch-lightning/pull/7888)) + + - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index bd57470dc2..059006a803 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -12,29 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions to operate on metric values. """ +import numbers +from typing import Any import torch +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException -def metrics_to_scalars(metrics: dict) -> dict: - """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ +def metrics_to_scalars(metrics: Any) -> Any: + """Recursively walk through a collection and convert single-item tensors to scalar values""" - # TODO: this is duplicated in MetricsHolder. should be unified - new_metrics = {} - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - if v.numel() != 1: - raise MisconfigurationException( - f"The metric `{k}` does not contain a single element" - f" thus it cannot be converted to float. Found `{v}`" - ) - v = v.item() + def to_item(value: torch.Tensor) -> numbers.Number: + if value.numel() != 1: + raise MisconfigurationException( + f"The metric `{value}` does not contain a single element" + f" thus it cannot be converted to float." + ) + return value.item() - if isinstance(v, dict): - v = metrics_to_scalars(v) - - new_metrics[k] = v - - return new_metrics + return apply_to_collection(metrics, torch.Tensor, to_item) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 2e6234bb98..adf26467e8 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -487,7 +487,7 @@ def test_metric_holder_raises(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - match = "The metric `test` does not contain a single element" + match = "The metric `.*` does not contain a single element" with pytest.raises(MisconfigurationException, match=match): trainer.validate(model) with pytest.raises(MisconfigurationException, match=match):