Use `apply_to_collection` in `metrics_to_scalars` (#7888)
* Use `apply_to_collection` in `metrics_to_scalars`
* Typing
* Update CHANGELOG
* Update pytorch_lightning/utilities/metrics.py
* Whitespace

commit b74f8ac149
parent 0fda862274
CHANGELOG.md:

```diff
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
 
 
+- Changed `metrics_to_scalars` to work with any collection or value ([#7888](https://github.com/PyTorchLightning/pytorch-lightning/pull/7888))
+
+
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
```
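With this change `metrics_to_scalars` is no longer limited to a `dict`: any collection or bare value can be passed, and single-element tensors anywhere inside it are converted to Python scalars. A minimal sketch of the new behaviour (the metric names and values below are made up for illustration):

```python
import torch

from pytorch_lightning.utilities.metrics import metrics_to_scalars

# Hypothetical inputs: nested dicts/lists and a bare tensor all work now.
metrics = {"loss": torch.tensor(0.25), "extra": [torch.tensor(1.0), 3]}

print(metrics_to_scalars(metrics))            # {'loss': 0.25, 'extra': [1.0, 3]}
print(metrics_to_scalars(torch.tensor(2.0)))  # 2.0 -- a bare value is also accepted
```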
pytorch_lightning/utilities/metrics.py:

```diff
@@ -12,29 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Helper functions to operate on metric values. """
+import numbers
+from typing import Any
 
 import torch
 
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
-def metrics_to_scalars(metrics: dict) -> dict:
-    """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """
+def metrics_to_scalars(metrics: Any) -> Any:
+    """Recursively walk through a collection and convert single-item tensors to scalar values"""
 
-    # TODO: this is duplicated in MetricsHolder. should be unified
-    new_metrics = {}
-    for k, v in metrics.items():
-        if isinstance(v, torch.Tensor):
-            if v.numel() != 1:
-                raise MisconfigurationException(
-                    f"The metric `{k}` does not contain a single element"
-                    f" thus it cannot be converted to float. Found `{v}`"
-                )
-            v = v.item()
+    def to_item(value: torch.Tensor) -> numbers.Number:
+        if value.numel() != 1:
+            raise MisconfigurationException(
+                f"The metric `{value}` does not contain a single element"
+                f" thus it cannot be converted to float."
+            )
+        return value.item()
 
-        if isinstance(v, dict):
-            v = metrics_to_scalars(v)
-
-        new_metrics[k] = v
-
-    return new_metrics
+    return apply_to_collection(metrics, torch.Tensor, to_item)
```
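The rewrite delegates the traversal to `apply_to_collection`, which recurses through nested dicts, lists and tuples and applies the supplied function only to elements of the requested dtype, leaving everything else untouched. A small sketch of that helper in isolation, using made-up data:

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection

# Made-up nested structure: only the tensors are converted; other values are
# returned unchanged and the original nesting is preserved.
nested = {"a": torch.tensor(1.0), "b": {"c": [torch.tensor(2.0), "keep-me"]}}
converted = apply_to_collection(nested, torch.Tensor, lambda t: t.item())

print(converted)  # {'a': 1.0, 'b': {'c': [2.0, 'keep-me']}}
```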
```diff
@@ -487,7 +487,7 @@ def test_metric_holder_raises(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
 
-    match = "The metric `test` does not contain a single element"
+    match = "The metric `.*` does not contain a single element"
     with pytest.raises(MisconfigurationException, match=match):
         trainer.validate(model)
     with pytest.raises(MisconfigurationException, match=match):
```
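The test now matches `.*` instead of the literal key `test` because the new error message embeds the offending tensor itself rather than the metric's key, so the exact text is no longer predictable. Since `pytest.raises(..., match=...)` treats the pattern as a regular expression searched against the exception message, the wildcard still pins down the rest of the text. A standalone sketch of that matching behaviour (using a plain `ValueError` purely for illustration):

```python
import pytest

# `match` is interpreted as a regex searched against str(exception), so `.*`
# absorbs whatever repr the failing tensor happens to have.
with pytest.raises(ValueError, match="The metric `.*` does not contain a single element"):
    raise ValueError("The metric `tensor([1., 2.])` does not contain a single element")
```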