Fix ROC metric for CUDA tensors (#2304)

* Fix ROC metric for CUDA tensors

Previously the roc metric (and auroc) raised an error when passed CUDA
tensors, because torch.tensor was constructed without specifying a device
and therefore defaulted to the CPU. This fixes the error by using F.pad
instead, which stays on the input tensor's device (see the sketch after
the commit metadata below).

* Update test_classification.py

* Update test_classification.py

* chlog

* Update test_classification.py

* Update test_classification.py

* Update tests/metrics/functional/test_classification.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update test_classification.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Tri Dao 2020-06-23 06:19:16 -07:00 committed by GitHub
parent 92f122e0df
commit 29179dbfcc
3 changed files with 18 additions and 12 deletions
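To make the failure mode concrete, here is a minimal sketch, assuming a CUDA-capable machine; the tensor values and the variable n are made up for illustration:

import torch
from torch.nn import functional as F

# Stand-ins for the metric's intermediate state.
distinct_value_indices = torch.tensor([0, 2, 5], device='cuda')
n = 10  # stands in for target.size(0)

# Old construction: torch.tensor(...) defaults to the CPU, so torch.cat
# mixes a CUDA tensor with a CPU tensor and raises a RuntimeError.
# torch.cat([distinct_value_indices, torch.tensor([n - 1])])

# Fixed construction: F.pad appends one element on the right and keeps
# the device and dtype of its input.
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=n - 1)
assert threshold_idxs.device == distinct_value_indices.device
assert threshold_idxs.tolist() == [0, 2, 5, 9]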

CHANGELOG.md

@@ -20,8 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
 - Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
+- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304))

 ## [0.8.1] - 2020-06-19

pytorch_lightning/metrics/functional/classification.py

@@ -3,6 +3,7 @@ from functools import wraps
 from typing import Optional, Tuple, Callable

 import torch
+from torch.nn import functional as F

 from pytorch_lightning.metrics.functional.reduction import reduce
 from pytorch_lightning.utilities import rank_zero_warn

@@ -500,8 +501,7 @@ def _binary_clf_curve(
     # the indices associated with the distinct values. We also
     # concatenate a value for the end of the curve.
     distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
-    threshold_idxs = torch.cat([distinct_value_indices,
-                                torch.tensor([target.size(0) - 1])])
+    threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

     target = (target == pos_label).to(torch.long)
     tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
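For reference, a small CPU-only check (illustrative, not part of the diff) that the new F.pad construction matches the old torch.cat one:

import torch
from torch.nn import functional as F

distinct_value_indices = torch.tensor([1, 3, 4])
n = 10  # stands in for target.size(0)

old = torch.cat([distinct_value_indices, torch.tensor([n - 1])])  # CPU-only
new = F.pad(distinct_value_indices, (0, 1), value=n - 1)          # device-preserving
assert torch.equal(old, new)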

tests/metrics/functional/test_classification.py

@@ -45,25 +45,30 @@ from pytorch_lightning.metrics.functional.classification import (
 ])
 def test_against_sklearn(sklearn_metric, torch_metric):
     """Compare PL metrics to sklearn version."""
-    pred = torch.randint(10, (500,))
-    target = torch.randint(10, (500,))
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    pred = torch.randint(10, (500,), device=device)
+    target = torch.randint(10, (500,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))

-    pred = torch.randint(10, (200,))
-    target = torch.randint(5, (200,))
+    pred = torch.randint(10, (200,), device=device)
+    target = torch.randint(5, (200,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))

-    pred = torch.randint(5, (200,))
-    target = torch.randint(10, (200,))
+    pred = torch.randint(5, (200,), device=device)
+    target = torch.randint(10, (200,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))
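The test change above follows the usual device-agnostic pattern: sklearn only accepts CPU numpy arrays, while the Lightning metric runs on whatever device the tensors live on. A short sketch of that hand-off, with illustrative names:

import torch

# Exercise the CUDA path when a GPU is available, fall back to CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred = torch.randint(10, (500,), device=device)

# numpy-based libraries such as sklearn need CPU arrays.
pred_np = pred.cpu().detach().numpy()
assert pred_np.shape == (500,)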