diff --git a/CHANGELOG.md b/CHANGELOG.md index e3371a3db7..863d0d577b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298)) -- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) +- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304)) +- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) ## [0.8.1] - 2020-06-19 diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index c5d3613640..4e23898738 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -3,6 +3,7 @@ from functools import wraps from typing import Optional, Tuple, Callable import torch +from torch.nn import functional as F from pytorch_lightning.metrics.functional.reduction import reduce from pytorch_lightning.utilities import rank_zero_warn @@ -500,8 +501,7 @@ def _binary_clf_curve( # the indices associated with the distinct values. We also # concatenate a value for the end of the curve. distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = torch.cat([distinct_value_indices, - torch.tensor([target.size(0) - 1])]) + threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) target = (target == pos_label).to(torch.long) tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 6a06ef4a4a..6aab85fa7b 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -45,25 +45,30 @@ from pytorch_lightning.metrics.functional.classification import ( ]) def test_against_sklearn(sklearn_metric, torch_metric): """Compare PL metrics to sklearn version.""" - pred = torch.randint(10, (500,)) - target = torch.randint(10, (500,)) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + pred = torch.randint(10, (500,), device=device) + target = torch.randint(10, (500,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target)) - pred = torch.randint(10, (200,)) - target = torch.randint(5, (200,)) + pred = torch.randint(10, (200,), device=device) + target = torch.randint(5, (200,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target)) - pred = torch.randint(5, (200,)) - target = torch.randint(10, (200,)) + pred = torch.randint(5, (200,), device=device) + target = torch.randint(10, (200,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target))