Rework of Sklearn Metrics (#1327)
* Create utils.py
* Create __init__.py
* redo sklearn metrics
* add some more metrics
* add sklearn metrics
* Create __init__.py
* redo sklearn metrics
* New metric classes (#1326)
* Create metrics package
* Create metric.py
* Create utils.py
* Create __init__.py
* add tests for metric utils
* add docstrings for metrics utils
* add function to recursively apply other function to collection
* add tests for this function
* update test
* Update pytorch_lightning/metrics/metric.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* update metric name
* remove example docs
* fix tests
* add metric tests
* fix to tensor conversion
* fix apply to collection
* Update CHANGELOG.md
* Update pytorch_lightning/metrics/metric.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* remove tests from init
* add missing type annotations
* rename utils to convertors
* Create metrics.rst
* Update index.rst
* Update index.rst
* Update pytorch_lightning/metrics/convertors.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/metrics/convertors.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* add doctest example
* rename file and fix imports
* added parametrized test
* replace lambda with inlined function
* rename apply_to_collection to apply_func
* Separated class description from init args
* Apply suggestions from code review
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* adjust random values
* suppress output when seeding
* remove gpu from doctest
* Add requested changes and add ellipsis for doctest
* forgot to push these files...
* add explicit check for dtype to convert to
* fix ddp tests
* remove explicit ddp destruction
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* add sklearn metrics
* start adding sklearn tests
* fix typo
* return x and y only for curves
* fix typo
* add missing tests for sklearn funcs
* imports
* __all__
* imports
* fix sklearn arguments
* fix imports
* update requirements
* Update CHANGELOG.md
* Update test_sklearn_metrics.py
* formatting
* formatting
* format
* fix all warnings and formatting problems
* Update environment.yml
* Update requirements-extra.txt
* Update environment.yml
* Update requirements-extra.txt
* fix all warnings and formatting problems
* Update CHANGELOG.md
* docs
* inherit
* docs inherit.
* docs
* Apply suggestions from code review
  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
* docs
* req
* min
* Apply suggestions from code review
  Co-authored-by: Tullie Murrell <tulliemurrell@gmail.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Tullie Murrell <tulliemurrell@gmail.com>
commit bd49b07fbb (parent 16a7326e52)
@@ -64,8 +64,12 @@ references:
     name: Make Documentation
     command: |
       # First run the same pipeline as Read-The-Docs
+      sudo apt-get update && sudo apt-get install -y cmake
+      sudo pip install -r docs/requirements.txt
       # apt-get update && apt-get install -y cmake
       # using: https://hub.docker.com/r/readthedocs/build
       # we need to use py3.7 or higher because of an issue with metaclass inheritance
       pyenv global 3.7.3
       python --version
       pip install -r docs/requirements.txt
       cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"

 test_docs: &test_docs
@@ -81,7 +85,7 @@ jobs:

   Build-Docs:
     docker:
-      - image: circleci/python:3.7
+      - image: readthedocs/build:latest
     steps:
       - checkout
       - *make_docs
@@ -68,9 +68,9 @@ jobs:
     - name: Set min. dependencies
      if: matrix.requires == 'minimal'
      run: |
-       python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
-       python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
-       python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
+       python -c "req = open('requirements.txt').read().replace('>=', '==') ; open('requirements.txt', 'w').write(req)"
+       python -c "req = open('requirements-extra.txt').read().replace('>=', '==') ; open('requirements-extra.txt', 'w').write(req)"
+       python -c "req = open('tests/requirements-devel.txt').read().replace('>=', '==') ; open('tests/requirements-devel.txt', 'w').write(req)"

     # Note: This uses an internal pip API and may not always work
     # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
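For clarity: the old `replace('>', '=')` also mangled strict `>` pins into invalid `=` specifiers, while the new `replace('>=', '==')` only rewrites lower bounds into exact pins. A quick sketch of the difference (the sample requirement lines are illustrative, not taken from this repo):

    # Illustrative only: what the CI one-liners above do to a requirements file.
    req = "torch>=1.3\nfoo>2.0\n"
    print(req.replace('>', '='))    # old: "torch==1.3\nfoo=2.0\n" -- "foo=2.0" is not valid pip syntax
    print(req.replace('>=', '=='))  # new: "torch==1.3\nfoo>2.0\n" -- strict pins are left untouched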
@@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


 ## [unreleased] - YYYY-MM-DD

 ### Added

@@ -23,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
-- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
+- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
+- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
 - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
 - Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
 - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
@@ -90,6 +90,7 @@ extensions = [
     'sphinx.ext.linkcode',
     'sphinx.ext.autosummary',
     'sphinx.ext.napoleon',
     'sphinx.ext.imgmath',
     'recommonmark',
+    'sphinx.ext.autosectionlabel',
     # 'm2r',
@@ -26,6 +26,10 @@ dependencies:
   - autopep8
   - check-manifest
   - twine==1.13.0
   - pillow<7.0.0
+  - scipy>=0.13.3
+  - scikit-learn>=0.20.0


   - pip:
     - test-tube>=0.7.5
@@ -27,6 +27,8 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Generator, Union

+from torch.nn import Module
+
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F

@@ -47,7 +49,7 @@ DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered
 # --- Utility functions ---


-def _make_trainable(module: torch.nn.Module) -> None:
+def _make_trainable(module: Module) -> None:
     """Unfreezes a given module.

     Args:

@@ -58,7 +60,7 @@ def _make_trainable(module: torch.nn.Module) -> None:
     module.train()


-def _recursive_freeze(module: torch.nn.Module,
+def _recursive_freeze(module: Module,
                       train_bn: bool = True) -> None:
     """Freezes the layers of a given module.


@@ -80,7 +82,7 @@ def _recursive_freeze(module: torch.nn.Module,
         _recursive_freeze(module=child, train_bn=train_bn)


-def freeze(module: torch.nn.Module,
+def freeze(module: Module,
            n: Optional[int] = None,
            train_bn: bool = True) -> None:
     """Freezes the layers up to index n (if n is not None).

@@ -101,7 +103,7 @@ def freeze(module: torch.nn.Module,
         _make_trainable(module=child)


-def filter_params(module: torch.nn.Module,
+def filter_params(module: Module,
                   train_bn: bool = True) -> Generator:
     """Yields the trainable parameters of a given module.


@@ -124,7 +126,7 @@ def filter_params(module: torch.nn.Module,
                 yield param


-def _unfreeze_and_add_param_group(module: torch.nn.Module,
+def _unfreeze_and_add_param_group(module: Module,
                                   optimizer: Optimizer,
                                   lr: Optional[float] = None,
                                   train_bn: bool = True):
@@ -4,9 +4,10 @@ Module to describe gradients
 from typing import Dict, Union

 import torch
+from torch.nn import Module


-class GradInformation(torch.nn.Module):
+class GradInformation(Module):

     def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
         """Compute each parameter's gradient's norm and their overall norm.
@@ -2,6 +2,7 @@ from typing import Any

 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from pytorch_lightning.utilities import move_data_to_device

@@ -14,7 +15,7 @@ else:
     APEX_AVAILABLE = True


-class ModelHooks(torch.nn.Module):
+class ModelHooks(Module):

     # TODO: remove in v0.9.0
     def on_sanity_check_start(self):
@@ -22,3 +22,9 @@ inputs to and outputs from numpy as well as automated ddp syncing.


 """
+
+from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
+from pytorch_lightning.metrics.sklearn import (
+    SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
+    Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
+from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import torch
 import torch.distributed
+from torch.nn import Module

 from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
 from pytorch_lightning.utilities.apply_func import apply_to_collection

@@ -11,7 +12,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixi
 __all__ = ['Metric', 'TensorMetric', 'NumpyMetric']


-class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
+class Metric(ABC, DeviceDtypeModuleMixin, Module):
     """
     Abstract base class for metric implementation.

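For orientation, a minimal sketch of how the reworked base classes are meant to be subclassed, based on the `NumpyMetric` API visible in this PR; the `MeanError` metric itself is a hypothetical example, not part of the change set:

    import numpy as np

    from pytorch_lightning.metrics.metric import NumpyMetric


    class MeanError(NumpyMetric):  # hypothetical example, not part of this PR
        def __init__(self):
            super().__init__(name='mean_error')

        def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> float:
            # tensor inputs arrive here already converted to numpy; the return
            # value is converted back to a tensor (and DDP-synced) by the base class
            return np.mean(np.abs(y_pred - y_true))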
@@ -0,0 +1,716 @@
from typing import Any, Optional, Union, Sequence

import numpy as np
import torch

from pytorch_lightning import _logger as lightning_logger
from pytorch_lightning.metrics.metric import NumpyMetric

__all__ = [
    'SklearnMetric',
    'Accuracy',
    'AveragePrecision',
    'AUC',
    'ConfusionMatrix',
    'F1',
    'FBeta',
    'Precision',
    'Recall',
    'PrecisionRecallCurve',
    'ROC',
    'AUROC'
]


class SklearnMetric(NumpyMetric):
    """
    Bridge between PyTorch Lightning and scikit-learn metrics

    Warning:
        Every metric call will cause a GPU synchronization, which may slow down your code

    Note:
        The order of targets and predictions may be different from the order typically used in PyTorch
    """
    def __init__(self, metric_name: str,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM, **kwargs):
        """
        Args:
            metric_name: the metric name to import and compute from ``sklearn.metrics``
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
            **kwargs: additional keyword arguments (will be forwarded to metric call)
        """
        super().__init__(name=metric_name, reduce_group=reduce_group,
                         reduce_op=reduce_op)

        self.metric_kwargs = kwargs
        lightning_logger.debug(
            f'Metric {self.__class__.__name__} is using Sklearn as backend, meaning that'
            ' every metric call will cause a GPU synchronization, which may slow down your code'
        )

    @property
    def metric_fn(self):
        import sklearn.metrics
        return getattr(sklearn.metrics, self.name)

    def forward(self, *args, **kwargs) -> Union[np.ndarray, int, float]:
        """
        Carries the actual metric computation

        Args:
            *args: Positional arguments forwarded to metric call (should be already converted to numpy)
            **kwargs: keyword arguments forwarded to metric call (should be already converted to numpy)

        Return:
            the metric value (will be converted to tensor by baseclass)

        """
        return self.metric_fn(*args, **kwargs, **self.metric_kwargs)

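A minimal usage sketch for the bridge class; it assumes scikit-learn is installed, and `balanced_accuracy_score` stands in for any name resolvable in `sklearn.metrics`:

    import torch

    from pytorch_lightning.metrics.sklearn import SklearnMetric

    metric = SklearnMetric(metric_name='balanced_accuracy_score')
    score = metric(y_pred=torch.tensor([0, 1, 1, 0]),
                   y_true=torch.tensor([0, 1, 0, 0]))
    # inputs are converted to numpy before the sklearn call and the result
    # comes back as a tensor; without an initialized process group the DDP
    # sync in the base class is a no-op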
class Accuracy(SklearnMetric):
    """
    Calculates the Accuracy Score

    Warning:
        Every metric call will cause a GPU synchronization, which may slow down your code
    """
    def __init__(self, normalize: bool = True,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            normalize: If ``False``, return the number of correctly classified samples.
                Otherwise, return the fraction of correctly classified samples.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(metric_name='accuracy_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         normalize=normalize)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> float:
        """
        Computes the accuracy

        Args:
            y_pred: the array containing the predictions (already in categorical form)
            y_true: the array containing the targets (in categorical form)
            sample_weight: Sample weights.

        Return:
            Accuracy Score

        """
        return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)


class AUC(SklearnMetric):
    """
    Calculates the Area Under the Curve using the trapezoidal rule

    Warning:
        Every metric call will cause a GPU synchronization, which may slow down your code
    """
    def __init__(self,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM
                 ):
        """
        Args:
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """

        super().__init__(metric_name='auc',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         )

    def forward(self, x: np.ndarray, y: np.ndarray) -> float:
        """
        Computes the AUC

        Args:
            x: x coordinates.
            y: y coordinates.

        Return:
            AUC calculated with trapezoidal rule

        """
        return super().forward(x=x, y=y)

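A short sketch of the two wrappers above (values are illustrative):

    import torch

    from pytorch_lightning.metrics.sklearn import Accuracy, AUC

    acc = Accuracy()(y_pred=torch.tensor([0, 1, 2, 2]),
                     y_true=torch.tensor([0, 1, 1, 2]))  # -> 0.75
    area = AUC()(x=torch.tensor([0.0, 0.5, 1.0]),
                 y=torch.tensor([0.0, 0.5, 1.0]))        # -> 0.5 by the trapezoidal rule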
class AveragePrecision(SklearnMetric):
    """
    Calculates the average precision (AP) score.
    """
    def __init__(self, average: Optional[str] = 'macro',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM
                 ):
        """
        Args:
            average: If None, the scores for each class are returned. Otherwise, this determines the type of
                averaging performed on the data:

                * If 'micro': Calculate metrics globally by considering each element of the label indicator
                  matrix as a label.
                * If 'macro': Calculate metrics for each label, and find their unweighted mean.
                  This does not take label imbalance into account.
                * If 'weighted': Calculate metrics for each label, and find their average, weighted by
                  support (the number of true instances for each label).
                * If 'samples': Calculate metrics for each instance, and find their average.

            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('average_precision_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         average=average)

    def forward(self, y_score: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> float:
        """
        Args:
            y_score: Target scores, can either be probability estimates of the positive class,
                confidence values, or binary decisions.
            y_true: True binary labels in binary label indicators.
            sample_weight: Sample weights.

        Return:
            average precision score
        """
        return super().forward(y_score=y_score, y_true=y_true,
                               sample_weight=sample_weight)


class ConfusionMatrix(SklearnMetric):
    """
    Compute confusion matrix to evaluate the accuracy of a classification
    By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
    is equal to the number of observations known to be in group :math:`i` but
    predicted to be in group :math:`j`.
    """
    def __init__(self, labels: Optional[Sequence] = None,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM
                 ):
        """
        Args:
            labels: List of labels to index the matrix. This may be used to reorder
                or select a subset of labels.
                If none is given, those that appear at least once
                in ``y_true`` or ``y_pred`` are used in sorted order.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('confusion_matrix',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         labels=labels)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
        """
        Args:
            y_pred: Estimated targets as returned by a classifier.
            y_true: Ground truth (correct) target values.

        Return:
            Confusion matrix (array of shape [n_classes, n_classes])

        """
        return super().forward(y_pred=y_pred, y_true=y_true)

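And similarly for the confusion matrix (label values are illustrative; the output is an ``[n_classes, n_classes]`` tensor):

    import torch

    from pytorch_lightning.metrics.sklearn import ConfusionMatrix

    cm = ConfusionMatrix()(y_pred=torch.tensor([0, 1, 1, 2]),
                           y_true=torch.tensor([0, 0, 1, 2]))  # -> 3x3 matrix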
class F1(SklearnMetric):
    r"""
    Compute the F1 score, also known as balanced F-score or F-measure
    The F1 score can be interpreted as a weighted average of the precision and
    recall, where an F1 score reaches its best value at 1 and worst score at 0.
    The relative contribution of precision and recall to the F1 score are
    equal. The formula for the F1 score is:

    .. math::

        F_1 = 2 \cdot \frac{precision \cdot recall}{precision + recall}

    In the multi-class and multi-label case, this is the weighted average of
    the F1 score of each class.

    References
        - [1] `Wikipedia entry for the F1-score
          <http://en.wikipedia.org/wiki/F1_score>`_
    """

    def __init__(self, labels: Optional[Sequence] = None,
                 pos_label: Union[str, int] = 1,
                 average: Optional[str] = 'binary',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            labels: Integer array of labels.
            pos_label: The class to report if ``average='binary'``.
            average: This parameter is required for multiclass/multilabel targets.
                If ``None``, the scores for each class are returned. Otherwise, this
                determines the type of averaging performed on the data:

                * ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                * ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                * ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean. This does not take label imbalance into account.
                * ``'weighted'``:
                    Calculate metrics for each label, and find their average, weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                * ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).

                Note that if ``pos_label`` is given in binary classification with
                `average != 'binary'`, only that positive class is reported. This
                behavior is deprecated and will change in version 0.18.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('f1_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         labels=labels,
                         pos_label=pos_label,
                         average=average)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            y_pred: Estimated targets as returned by a classifier.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.

        Return:
            F1 score of the positive class in binary classification or weighted
            average of the F1 scores of each class for the multiclass task.

        """
        return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)


class FBeta(SklearnMetric):
    """
    Compute the F-beta score. The `beta` parameter determines the weight of precision in the combined
    score. ``beta < 1`` lends more weight to precision, while ``beta > 1``
    favors recall (``beta -> 0`` considers only precision, ``beta -> inf``
    only recall).

    References:
        - [1] R. Baeza-Yates and B. Ribeiro-Neto (2011).
          Modern Information Retrieval. Addison Wesley, pp. 327-328.
        - [2] `Wikipedia entry for the F1-score
          <http://en.wikipedia.org/wiki/F1_score>`_
    """

    def __init__(self, beta: float, labels: Optional[Sequence] = None,
                 pos_label: Union[str, int] = 1,
                 average: Optional[str] = 'binary',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            beta: Weight of precision in harmonic mean.
            labels: Integer array of labels.
            pos_label: The class to report if ``average='binary'``.
            average: This parameter is required for multiclass/multilabel targets.
                If ``None``, the scores for each class are returned. Otherwise, this
                determines the type of averaging performed on the data:

                * ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                * ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                * ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean. This does not take label imbalance into account.
                * ``'weighted'``:
                    Calculate metrics for each label, and find their average, weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                * ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).

                Note that if ``pos_label`` is given in binary classification with
                `average != 'binary'`, only that positive class is reported. This
                behavior is deprecated and will change in version 0.18.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('fbeta_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         beta=beta,
                         labels=labels,
                         pos_label=pos_label,
                         average=average)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            y_pred: Estimated targets as returned by a classifier.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.


        Return:
            FBeta score of the positive class in binary classification or weighted
            average of the FBeta scores of each class for the multiclass task.

        """
        return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)


class Precision(SklearnMetric):
    """
    Compute the precision
    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.
    The best value is 1 and the worst value is 0.
    """

    def __init__(self, labels: Optional[Sequence] = None,
                 pos_label: Union[str, int] = 1,
                 average: Optional[str] = 'binary',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            labels: Integer array of labels.
            pos_label: The class to report if ``average='binary'``.
            average: This parameter is required for multiclass/multilabel targets.
                If ``None``, the scores for each class are returned. Otherwise, this
                determines the type of averaging performed on the data:

                * ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                * ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                * ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean. This does not take label imbalance into account.
                * ``'weighted'``:
                    Calculate metrics for each label, and find their average, weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                * ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).

                Note that if ``pos_label`` is given in binary classification with
                `average != 'binary'`, only that positive class is reported. This
                behavior is deprecated and will change in version 0.18.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('precision_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         labels=labels,
                         pos_label=pos_label,
                         average=average)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            y_pred: Estimated targets as returned by a classifier.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.

        Return:
            Precision of the positive class in binary classification or weighted
            average of the precision of each class for the multiclass task.

        """
        return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)


class Recall(SklearnMetric):
    """
    Compute the recall
    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.
    The best value is 1 and the worst value is 0.
    """

    def __init__(self, labels: Optional[Sequence] = None,
                 pos_label: Union[str, int] = 1,
                 average: Optional[str] = 'binary',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            labels: Integer array of labels.
            pos_label: The class to report if ``average='binary'``.
            average: This parameter is required for multiclass/multilabel targets.
                If ``None``, the scores for each class are returned. Otherwise, this
                determines the type of averaging performed on the data:

                * ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                * ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                * ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean. This does not take label imbalance into account.
                * ``'weighted'``:
                    Calculate metrics for each label, and find their average, weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                * ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).

                Note that if ``pos_label`` is given in binary classification with
                `average != 'binary'`, only that positive class is reported. This
                behavior is deprecated and will change in version 0.18.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('recall_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         labels=labels,
                         pos_label=pos_label,
                         average=average)

    def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            y_pred: Estimated targets as returned by a classifier.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.

        Return:
            Recall of the positive class in binary classification or weighted
            average of the recall of each class for the multiclass task.

        """
        return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)


class PrecisionRecallCurve(SklearnMetric):
    """
    Compute precision-recall pairs for different probability thresholds

    Note:
        This implementation is restricted to the binary classification task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.
    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.
    The last precision and recall values are 1. and 0. respectively and do not
    have a corresponding threshold. This ensures that the graph starts on the
    x axis.
    """

    def __init__(self,
                 pos_label: Union[str, int] = 1,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            pos_label: The class to report if ``average='binary'``.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('precision_recall_curve',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         pos_label=pos_label)

    def forward(self, probas_pred: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            probas_pred: Estimated probabilities or decision function.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.

        Returns:
            precision:
                Precision values such that element i is the precision of
                predictions with score >= thresholds[i] and the last element is 1.
            recall:
                Decreasing recall values such that element i is the recall of
                predictions with score >= thresholds[i] and the last element is 0.
            thresholds:
                Increasing thresholds on the decision function used to compute
                precision and recall.

        """
        # only return x and y here, since for now we cannot auto-convert elements of multiple length.
        # Will be fixed in native implementation
        return np.array(
            super().forward(probas_pred=probas_pred, y_true=y_true, sample_weight=sample_weight)[:2])

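Because the thresholds array is one element shorter than the precision and recall arrays, it cannot ride along through the automatic tensor conversion; the wrapper therefore stacks only the two equal-length curves into a single 2 x N result. A usage sketch (inputs are illustrative):

    import torch

    from pytorch_lightning.metrics.sklearn import PrecisionRecallCurve

    curve = PrecisionRecallCurve()(probas_pred=torch.rand(16),
                                   y_true=torch.randint(2, size=(16,)))
    precision, recall = curve  # two equal-length rows; thresholds are dropped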
class ROC(SklearnMetric):
    """
    Compute Receiver operating characteristic (ROC)

    Note:
        this implementation is restricted to the binary classification task.
    """

    def __init__(self,
                 pos_label: Union[str, int] = 1,
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
        """
        Args:
            pos_label: The class to report if ``average='binary'``.
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.

        References:
            - [1] `Wikipedia entry for the Receiver operating characteristic
              <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
        """
        super().__init__('roc_curve',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         pos_label=pos_label)

    def forward(self, y_score: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
        """
        Args:
            y_score: Target scores, can either be probability estimates of the positive
                class or confidence values.
            y_true: Ground truth (correct) target values.
            sample_weight: Sample weights.

        Returns:
            fpr:
                Increasing false positive rates such that element i is the false
                positive rate of predictions with score >= thresholds[i].
            tpr:
                Increasing true positive rates such that element i is the true
                positive rate of predictions with score >= thresholds[i].
            thresholds:
                Decreasing thresholds on the decision function used to compute
                fpr and tpr. `thresholds[0]` represents no instances being predicted
                and is arbitrarily set to `max(y_score) + 1`.

        """
        return np.array(super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight)[:2])


class AUROC(SklearnMetric):
    """
    Compute Area Under the Curve (AUC) from prediction scores

    Note:
        this implementation is restricted to the binary classification task
        or multilabel classification task in label indicator format.
    """

    def __init__(self, average: Optional[str] = 'macro',
                 reduce_group: Any = torch.distributed.group.WORLD,
                 reduce_op: Any = torch.distributed.ReduceOp.SUM
                 ):
        """
        Args:
            average: If None, the scores for each class are returned. Otherwise, this determines the type of
                averaging performed on the data:

                * If 'micro': Calculate metrics globally by considering each element of the label indicator
                  matrix as a label.
                * If 'macro': Calculate metrics for each label, and find their unweighted mean.
                  This does not take label imbalance into account.
                * If 'weighted': Calculate metrics for each label, and find their average, weighted by
                  support (the number of true instances for each label).
                * If 'samples': Calculate metrics for each instance, and find their average.

            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__('roc_auc_score',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op,
                         average=average)

    def forward(self, y_score: np.ndarray, y_true: np.ndarray,
                sample_weight: Optional[np.ndarray] = None) -> float:
        """
        Args:
            y_score: Target scores, can either be probability estimates of the positive class,
                confidence values, or binary decisions.
            y_true: True binary labels in binary label indicators.
            sample_weight: Sample weights.

        Return:
            Area Under Receiver Operating Characteristic Curve
        """
        return super().forward(y_score=y_score, y_true=y_true,
                               sample_weight=sample_weight)
@@ -0,0 +1,130 @@
import numbers
from typing import Union, Any, Optional

import numpy as np
import torch
from torch.utils.data._utils.collate import default_convert

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _apply_to_inputs(func_to_apply, *dec_args, **dec_kwargs):
    def decorator_fn(func_to_decorate):
        def new_func(*args, **kwargs):
            args = func_to_apply(args, *dec_args, **dec_kwargs)
            kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
            return func_to_decorate(*args, **kwargs)

        return new_func

    return decorator_fn


def _apply_to_outputs(func_to_apply, *dec_args, **dec_kwargs):
    def decorator_fn(function_to_decorate):
        def new_func(*args, **kwargs):
            result = function_to_decorate(*args, **kwargs)
            return func_to_apply(result, *dec_args, **dec_kwargs)

        return new_func

    return decorator_fn

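These two helpers only thread a conversion function through a wrapped callable's inputs or outputs. A tiny standalone demonstration (the `double_all` and `add` names are hypothetical; in the real decorators the conversion is `apply_to_collection` with a tensor/numpy converter):

    def double_all(collection):  # hypothetical stand-in for the conversion function
        if isinstance(collection, dict):
            return {k: v * 2 for k, v in collection.items()}
        return tuple(v * 2 for v in collection)


    @_apply_to_inputs(double_all)
    def add(a, b):  # hypothetical decorated function
        return a + b


    add(1, 2)      # -> 6: both positional args were doubled first
    add(a=1, b=2)  # -> 6: keyword args go through the same conversion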
def _convert_to_tensor(data: Any) -> Any:
    """
    Maps all kind of collections and numbers to tensors

    Args:
        data: the data to convert to tensor

    Returns:
        the converted data

    """
    if isinstance(data, numbers.Number):
        return torch.tensor([data])
    else:
        return default_convert(data)


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
    """
    Converts all tensors and numpy arrays to numpy arrays
    Args:
        data: the tensor or array to convert to numpy

    Returns:
        the resulting numpy array

    """
    if isinstance(data, torch.Tensor):
        return data.cpu().detach().numpy()
    elif isinstance(data, numbers.Number):
        return np.array([data])
    return data


def _numpy_metric_conversion(func_to_decorate):
    # Applies collection conversion from tensor to numpy to all inputs
    # we need to include numpy arrays here, since otherwise they will also be treated as sequences
    func_convert_inputs = _apply_to_inputs(
        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
    # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
    func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
    return func_convert_in_out


def _tensor_metric_conversion(func_to_decorate):
    # Converts all inputs to tensor if possible
    func_convert_inputs = _apply_to_inputs(_convert_to_tensor)(func_to_decorate)
    # convert all outputs to tensor if possible
    return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)


def _sync_ddp(result: Union[torch.Tensor],
              group: Any = torch.distributed.group.WORLD,
              reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
              ) -> torch.Tensor:
    """
    Function to reduce the tensors from several ddp processes to one master process

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        reduced value

    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

    return result


def numpy_metric(group: Any = torch.distributed.group.WORLD,
                 reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
                                 group=group,
                                 reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))

    return decorator_fn


def tensor_metric(group: Any = torch.distributed.group.WORLD,
                  reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
                                 group=group,
                                 reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))

    return decorator_fn
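Putting it together, the public decorators wrap a plain metric function with the input/output conversion plus the DDP sync; a minimal sketch (the `rmse` function is hypothetical, not part of this PR):

    import numpy as np
    import torch

    from pytorch_lightning.metrics.converters import numpy_metric


    @numpy_metric()
    def rmse(pred: np.ndarray, target: np.ndarray):  # hypothetical metric
        # receives numpy arrays even when called with tensors; the result is
        # returned as a tensor and all-reduced across processes under DDP
        return np.sqrt(np.mean((pred - target) ** 2))


    rmse(torch.tensor([1., 3.]), torch.tensor([1., 1.]))  # -> roughly tensor([1.4142])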
@@ -1,9 +1,10 @@
 from typing import Union, Optional

 import torch
+from torch.nn import Module


-class DeviceDtypeModuleMixin(torch.nn.Module):
+class DeviceDtypeModuleMixin(Module):
     _device: ...
     _dtype: Union[str, torch.dtype]


@@ -25,7 +26,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         # Necessary to avoid infinite recursion
         raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

-    def to(self, *args, **kwargs) -> torch.nn.Module:
+    def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.

         This can be called as

@@ -91,7 +92,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):

         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[int] = None) -> torch.nn.Module:
+    def cuda(self, device: Optional[int] = None) -> Module:
         """Moves all model parameters and buffers to the GPU.
         This also makes associated parameters and buffers different objects. So
         it should be called before constructing optimizer if the module will

@@ -108,7 +109,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         self._device = torch.device('cuda', index=device)
         return super().cuda(device=device)

-    def cpu(self) -> torch.nn.Module:
+    def cpu(self) -> Module:
         """Moves all model parameters and buffers to the CPU.
         Returns:
             Module: self

@@ -116,7 +117,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         self._device = torch.device('cpu')
         return super().cpu()

-    def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module:
+    def type(self, dst_type: Union[str, torch.dtype]) -> Module:
         """Casts all parameters and buffers to :attr:`dst_type`.

         Arguments:

@@ -128,7 +129,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         self._dtype = dst_type
         return super().type(dst_type=dst_type)

-    def float(self) -> torch.nn.Module:
+    def float(self) -> Module:
         """Casts all floating point parameters and buffers to float datatype.

         Returns:

@@ -137,7 +138,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         self._dtype = torch.float
         return super().float()

-    def double(self) -> torch.nn.Module:
+    def double(self) -> Module:
         """Casts all floating point parameters and buffers to ``double`` datatype.

         Returns:

@@ -146,7 +147,7 @@ class DeviceDtypeModuleMixin(torch.nn.Module):
         self._dtype = torch.double
         return super().double()

-    def half(self) -> torch.nn.Module:
+    def half(self) -> Module:
         """Casts all floating point parameters and buffers to ``half`` datatype.

         Returns:
@@ -9,4 +9,6 @@ trains>=0.14.1
 matplotlib>=3.1.1
 # no need to install with [pytorch] as pytorch is already installed and torchvision is required only for Horovod examples
 horovod>=0.19.1
-omegaconf==2.0.0
+omegaconf>=2.0.0
+# scipy>=0.13.3
+scikit-learn>=0.20.0
@@ -1,7 +1,7 @@
 # the default package dependencies

-numpy>=1.15  # because some BLAS compilation issues
 tqdm>=4.41.0
+numpy>=1.16.4
 torch>=1.3
 tensorboard>=1.14
 future>=0.17.1  # required for builtins in setup.py
@@ -0,0 +1,86 @@
import numbers
from collections import Mapping, Sequence
from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import (accuracy_score, average_precision_score, auc, confusion_matrix, f1_score,
                             fbeta_score, precision_score, recall_score, precision_recall_curve, roc_curve,
                             roc_auc_score)

from pytorch_lightning.metrics.converters import _convert_to_numpy
from pytorch_lightning.metrics.sklearn import (Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
                                               Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.utilities.apply_func import apply_to_collection


def xy_only(func):
    def new_func(*args, **kwargs):
        return np.array(func(*args, **kwargs)[:2])

    return new_func


@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
    pytest.param(Accuracy(), accuracy_score,
                 {'y_pred': torch.randint(low=0, high=10, size=(128,)),
                  'y_true': torch.randint(low=0, high=10, size=(128,))}, id='Accuracy'),
    pytest.param(AUC(), auc, {'x': torch.arange(10, dtype=torch.float) / 10,
                              'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2,
                                                 0.2, 0.3, 0.5, 0.6, 0.7])}, id='AUC'),
    pytest.param(AveragePrecision(), average_precision_score,
                 {'y_score': torch.randint(2, size=(128,)),
                  'y_true': torch.randint(2, size=(128,))}, id='AveragePrecision'),
    pytest.param(ConfusionMatrix(), confusion_matrix,
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))}, id='ConfusionMatrix'),
    pytest.param(F1(average='macro'), partial(f1_score, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))}, id='F1'),
    pytest.param(FBeta(beta=0.5, average='macro'), partial(fbeta_score, beta=0.5, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))}, id='FBeta'),
    pytest.param(Precision(average='macro'), partial(precision_score, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))}, id='Precision'),
    pytest.param(Recall(average='macro'), partial(recall_score, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))}, id='Recall'),
    pytest.param(PrecisionRecallCurve(), xy_only(precision_recall_curve),
                 {'probas_pred': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))}, id='PrecisionRecallCurve'),
    pytest.param(ROC(), xy_only(roc_curve),
                 {'y_score': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))}, id='ROC'),
    pytest.param(AUROC(), roc_auc_score,
                 {'y_score': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))}, id='AUROC'),
])
def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
    numpy_inputs = apply_to_collection(
        inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    sklearn_result = sklearn_func(**numpy_inputs)
    lightning_result = metric_class(**inputs)

    sklearn_result = apply_to_collection(
        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    lightning_result = apply_to_collection(
        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    assert isinstance(lightning_result, type(sklearn_result))

    if isinstance(lightning_result, np.ndarray):
        assert np.allclose(lightning_result, sklearn_result)
    elif isinstance(lightning_result, Mapping):
        for key in lightning_result.keys():
            assert np.allclose(lightning_result[key], sklearn_result[key])

    elif isinstance(lightning_result, Sequence):
        for val_lightning, val_sklearn in zip(lightning_result, sklearn_result):
            assert np.allclose(val_lightning, val_sklearn)

    else:
        raise TypeError
@@ -8,4 +8,4 @@ flake8-black
 check-manifest
 twine==1.13.0
 black==19.10b0
-pre-commit>=1.21.0
+pre-commit>=1.0