[metrics] Accuracy num_classes error fix (#3764)
* change accuracy error to warning * changelog
This commit is contained in:
parent
8be002ccc7
commit
9a7d1a1876
|
@ -89,6 +89,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
|
||||
|
||||
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
|
||||
|
||||
## [0.9.0] - YYYY-MM-DD
|
||||
|
||||
### Added
|
||||
|
|
|
@ -87,8 +87,10 @@ def get_num_classes(
|
|||
if num_classes is None:
|
||||
num_classes = num_all_classes
|
||||
elif num_classes != num_all_classes:
|
||||
rank_zero_warn(f'You have set {num_classes} number of classes if different from'
|
||||
f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes')
|
||||
rank_zero_warn(f'You have set {num_classes} number of classes which is'
|
||||
f' different from predicted ({num_pred_classes}) and'
|
||||
f' target ({num_target_classes}) number of classes',
|
||||
RuntimeWarning)
|
||||
return num_classes
|
||||
|
||||
|
||||
|
@ -266,9 +268,6 @@ def accuracy(
|
|||
tensor(0.7500)
|
||||
|
||||
"""
|
||||
if not (target > 0).any() and num_classes is None:
|
||||
raise RuntimeError("cannot infer num_classes when target is all zero")
|
||||
|
||||
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
|
||||
pred=pred, target=target, num_classes=num_classes)
|
||||
|
||||
|
|
|
@ -194,8 +194,12 @@ def test_multilabel_accuracy():
|
|||
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
|
||||
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
accuracy(y2, torch.zeros_like(y2), class_reduction='none')
|
||||
# num_classes does not match extracted number from input we expect a warning
|
||||
with pytest.warns(RuntimeWarning,
|
||||
match=r'You have set .* number of classes which is'
|
||||
r' different from predicted (.*) and'
|
||||
r' target (.*) number of classes'):
|
||||
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
|
|
Loading…
Reference in New Issue