Use `apply_to_collection` in `metrics_to_scalars` (#7888)

* Use `apply_to_collection` in `metrics_to_scalars`

* Typing

* Update CHANGELOG

* Update pytorch_lightning/utilities/metrics.py

* Whitespace
This commit is contained in:
Carlos Mocholí 2021-06-08 18:54:32 +02:00 committed by GitHub
parent 0fda862274
commit b74f8ac149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 19 deletions

View File

@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
- Changed `metrics_to_scalars` to work with any collection or value ([#7888](https://github.com/PyTorchLightning/pytorch-lightning/pull/7888))
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))

View File

@ -12,29 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to operate on metric values. """
import numbers
from typing import Any
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
def metrics_to_scalars(metrics: dict) -> dict:
""" Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """
def metrics_to_scalars(metrics: Any) -> Any:
"""Recursively walk through a collection and convert single-item tensors to scalar values"""
# TODO: this is duplicated in MetricsHolder. should be unified
new_metrics = {}
for k, v in metrics.items():
if isinstance(v, torch.Tensor):
if v.numel() != 1:
def to_item(value: torch.Tensor) -> numbers.Number:
if value.numel() != 1:
raise MisconfigurationException(
f"The metric `{k}` does not contain a single element"
f" thus it cannot be converted to float. Found `{v}`"
f"The metric `{value}` does not contain a single element"
f" thus it cannot be converted to float."
)
v = v.item()
return value.item()
if isinstance(v, dict):
v = metrics_to_scalars(v)
new_metrics[k] = v
return new_metrics
return apply_to_collection(metrics, torch.Tensor, to_item)

View File

@ -487,7 +487,7 @@ def test_metric_holder_raises(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
match = "The metric `test` does not contain a single element"
match = "The metric `.*` does not contain a single element"
with pytest.raises(MisconfigurationException, match=match):
trainer.validate(model)
with pytest.raises(MisconfigurationException, match=match):