Use `apply_to_collection` in `metrics_to_scalars` (#7888)

* Use `apply_to_collection` in `metrics_to_scalars`

* Typing

* Update CHANGELOG

* Update pytorch_lightning/utilities/metrics.py

* Whitespace
This commit is contained in:
Carlos Mocholí 2021-06-08 18:54:32 +02:00 committed by GitHub
parent 0fda862274
commit b74f8ac149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 19 deletions

View File

@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
+- Changed `metrics_to_scalars` to work with any collection or value ([#7888](https://github.com/PyTorchLightning/pytorch-lightning/pull/7888))
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))

View File

@@ -12,29 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Helper functions to operate on metric values. """
+import numbers
+from typing import Any
+
 import torch

+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def metrics_to_scalars(metrics: dict) -> dict:
-    """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """
-    # TODO: this is duplicated in MetricsHolder. should be unified
-    new_metrics = {}
-    for k, v in metrics.items():
-        if isinstance(v, torch.Tensor):
-            if v.numel() != 1:
-                raise MisconfigurationException(
-                    f"The metric `{k}` does not contain a single element"
-                    f" thus it cannot be converted to float. Found `{v}`"
-                )
-            v = v.item()
-        if isinstance(v, dict):
-            v = metrics_to_scalars(v)
-        new_metrics[k] = v
-    return new_metrics
+def metrics_to_scalars(metrics: Any) -> Any:
+    """Recursively walk through a collection and convert single-item tensors to scalar values"""
+
+    def to_item(value: torch.Tensor) -> numbers.Number:
+        if value.numel() != 1:
+            raise MisconfigurationException(
+                f"The metric `{value}` does not contain a single element"
+                f" thus it cannot be converted to float."
+            )
+        return value.item()
+
+    return apply_to_collection(metrics, torch.Tensor, to_item)

View File

@@ -487,7 +487,7 @@ def test_metric_holder_raises(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    match = "The metric `test` does not contain a single element"
+    match = "The metric `.*` does not contain a single element"
     with pytest.raises(MisconfigurationException, match=match):
         trainer.validate(model)
     with pytest.raises(MisconfigurationException, match=match):