From 29179dbfcc125aa85716cb1985cd1a0ac2e0a4fd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 23 Jun 2020 06:19:16 -0700 Subject: [PATCH] Fix ROC metric for CUDA tensors (#2304) * Fix ROC metric for CUDA tensors Previously roc metric (and auroc) errors when passed in CUDA tensors, due to torch.tensor construction without specifying device. This fixes the error by using F.pad instead. * Update test_classification.py * Update test_classification.py * chlog * Update test_classification.py * Update test_classification.py * Update tests/metrics/functional/test_classification.py Co-authored-by: Jirka Borovec * Update test_classification.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 ++- .../metrics/functional/classification.py | 4 ++-- .../metrics/functional/test_classification.py | 23 +++++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e3371a3db7..863d0d577b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298)) -- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) +- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304)) +- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) ## [0.8.1] - 2020-06-19 diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index c5d3613640..4e23898738 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -3,6 +3,7 @@ from functools import wraps from typing import Optional, Tuple, Callable import torch +from torch.nn import functional as F from pytorch_lightning.metrics.functional.reduction import reduce from pytorch_lightning.utilities import rank_zero_warn @@ -500,8 +501,7 @@ def _binary_clf_curve( # the indices associated with the distinct values. We also # concatenate a value for the end of the curve. distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = torch.cat([distinct_value_indices, - torch.tensor([target.size(0) - 1])]) + threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) target = (target == pos_label).to(torch.long) tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 6a06ef4a4a..6aab85fa7b 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -45,25 +45,30 @@ from pytorch_lightning.metrics.functional.classification import ( ]) def test_against_sklearn(sklearn_metric, torch_metric): """Compare PL metrics to sklearn version.""" - pred = torch.randint(10, (500,)) - target = torch.randint(10, (500,)) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + pred = torch.randint(10, (500,), device=device) + target = torch.randint(10, (500,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target)) - pred = torch.randint(10, (200,)) - target = torch.randint(5, (200,)) + pred = torch.randint(10, (200,), device=device) + target = torch.randint(5, (200,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target)) - pred = torch.randint(5, (200,)) - target = torch.randint(10, (200,)) + pred = torch.randint(5, (200,), device=device) + target = torch.randint(10, (200,), device=device) assert torch.allclose( - torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()), dtype=torch.float, device=device), torch_metric(pred, target))