Fix ROC metric for CUDA tensors (#2304)
* Fix ROC metric for CUDA tensors Previously roc metric (and auroc) errors when passed in CUDA tensors, due to torch.tensor construction without specifying device. This fixes the error by using F.pad instead. * Update test_classification.py * Update test_classification.py * chlog * Update test_classification.py * Update test_classification.py * Update tests/metrics/functional/test_classification.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update test_classification.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
92f122e0df
commit
29179dbfcc
|
@ -20,8 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
|
||||
|
||||
- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
|
||||
- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304))
|
||||
|
||||
- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
|
||||
|
||||
## [0.8.1] - 2020-06-19
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from functools import wraps
|
|||
from typing import Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
@ -500,8 +501,7 @@ def _binary_clf_curve(
|
|||
# the indices associated with the distinct values. We also
|
||||
# concatenate a value for the end of the curve.
|
||||
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
|
||||
threshold_idxs = torch.cat([distinct_value_indices,
|
||||
torch.tensor([target.size(0) - 1])])
|
||||
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
|
||||
|
||||
target = (target == pos_label).to(torch.long)
|
||||
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
|
||||
|
|
|
@ -45,25 +45,30 @@ from pytorch_lightning.metrics.functional.classification import (
|
|||
])
|
||||
def test_against_sklearn(sklearn_metric, torch_metric):
|
||||
"""Compare PL metrics to sklearn version."""
|
||||
pred = torch.randint(10, (500,))
|
||||
target = torch.randint(10, (500,))
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
pred = torch.randint(10, (500,), device=device)
|
||||
target = torch.randint(10, (500,), device=device)
|
||||
|
||||
assert torch.allclose(
|
||||
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
|
||||
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
|
||||
torch_metric(pred, target))
|
||||
|
||||
pred = torch.randint(10, (200,))
|
||||
target = torch.randint(5, (200,))
|
||||
pred = torch.randint(10, (200,), device=device)
|
||||
target = torch.randint(5, (200,), device=device)
|
||||
|
||||
assert torch.allclose(
|
||||
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
|
||||
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
|
||||
torch_metric(pred, target))
|
||||
|
||||
pred = torch.randint(5, (200,))
|
||||
target = torch.randint(10, (200,))
|
||||
pred = torch.randint(5, (200,), device=device)
|
||||
target = torch.randint(10, (200,), device=device)
|
||||
|
||||
assert torch.allclose(
|
||||
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
|
||||
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
|
||||
torch_metric(pred, target))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue