[metrics] Accuracy num_classes error fix (#3764)

* change accuracy error to warning

* changelog
This commit is contained in:
Nicki Skafte 2020-10-01 13:00:42 +02:00 committed by GitHub
parent 8be002ccc7
commit 9a7d1a1876
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 7 deletions

View File

@ -89,6 +89,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
## [0.9.0] - YYYY-MM-DD
### Added

View File

@ -87,8 +87,10 @@ def get_num_classes(
if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(f'You have set {num_classes} number of classes if different from'
f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes')
rank_zero_warn(f'You have set {num_classes} number of classes which is'
f' different from predicted ({num_pred_classes}) and'
f' target ({num_target_classes}) number of classes',
RuntimeWarning)
return num_classes
@ -266,9 +268,6 @@ def accuracy(
tensor(0.7500)
"""
if not (target > 0).any() and num_classes is None:
raise RuntimeError("cannot infer num_classes when target is all zero")
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes)

View File

@ -194,8 +194,12 @@ def test_multilabel_accuracy():
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))
with pytest.raises(RuntimeError):
accuracy(y2, torch.zeros_like(y2), class_reduction='none')
# num_classes does not match extracted number from input we expect a warning
with pytest.warns(RuntimeWarning,
match=r'You have set .* number of classes which is'
r' different from predicted (.*) and'
r' target (.*) number of classes'):
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)
def test_accuracy():