* New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * move dtype device mixin to more general place * refactor to general device dtype mixin * add initial metric package description * change default to none for mac os * pep8 * fix import * Update index.rst * Update ci-testing.yml * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update CHANGELOG.md * Update pytorch_lightning/metrics/converters.py * readme * Update metric.py * Update pytorch_lightning/metrics/converters.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Jirka <jirka@pytorchlightning.ai>
This commit is contained in:
parent
ac76dfcf62
commit
9b629637b8
|
@ -54,6 +54,11 @@ jobs:
|
|||
run: |
|
||||
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)"
|
||||
|
||||
# versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996)
|
||||
- name: Setup MacOS Minimal
|
||||
if: runner.os == 'macOS' && matrix.requires ='minimal'
|
||||
run : |
|
||||
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch>=1.4') ; open('requirements.txt', 'w').write(req)"
|
||||
- name: Set min. dependencies
|
||||
if: matrix.requires == 'minimal'
|
||||
run: |
|
||||
|
@ -137,4 +142,4 @@ jobs:
|
|||
- name: Statistics
|
||||
if: success()
|
||||
run: |
|
||||
coverage report
|
||||
coverage report
|
||||
|
|
|
@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
|
||||
|
||||
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -32,7 +32,7 @@ removed until codecov badge isn't empy. likely a config error showing nothing on
|
|||
| Linux py3.6 [CPU] | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) |
|
||||
| Linux py3.7 [GPU] | - | - | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) |
|
||||
| Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |
|
||||
| OSX py3.6 / py3.7 / py3.8| [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |
|
||||
| OSX py3.6 / py3.7 / py3.8| - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |
|
||||
| Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
|
||||
|
||||
</center>
|
||||
|
|
|
@ -23,6 +23,7 @@ PyTorch Lightning Documentation
|
|||
hooks
|
||||
lightning-module
|
||||
loggers
|
||||
metrics
|
||||
trainer
|
||||
|
||||
.. toctree::
|
||||
|
@ -115,6 +116,7 @@ Indices and tables
|
|||
api/pytorch_lightning.core
|
||||
api/pytorch_lightning.callbacks
|
||||
api/pytorch_lightning.loggers
|
||||
api/pytorch_lightning.metrics
|
||||
api/pytorch_lightning.overrides
|
||||
api/pytorch_lightning.profiler
|
||||
api/pytorch_lightning.trainer
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
.. automodule:: pytorch_lightning.metrics
|
||||
:members:
|
||||
:noindex:
|
||||
:exclude-members:
|
|
@ -18,7 +18,7 @@ from pytorch_lightning.core.grads import GradInformation
|
|||
from pytorch_lightning.core.hooks import ModelHooks
|
||||
from pytorch_lightning.core.memory import ModelSummary
|
||||
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
|
||||
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
"""
|
||||
Metrics
|
||||
=======
|
||||
|
||||
Metrics are generally used to monitor model performance.
|
||||
|
||||
The following package aims to provide the most convenient ones as well
|
||||
as a structure to implement your custom metrics for all the fancy research
|
||||
you want to do.
|
||||
|
||||
For native PyTorch implementations of metrics, it is recommended to use
|
||||
the :class:`TensorMetric` which handles automated DDP syncing and conversions
|
||||
to tensors for all inputs and outputs.
|
||||
|
||||
If your metrics implementation works on numpy, just use the
|
||||
:class:`NumpyMetric`, which handles the automated conversion of
|
||||
inputs to and outputs from numpy as well as automated ddp syncing.
|
||||
|
||||
.. warning:: Employing numpy in your metric calculation might slow
|
||||
down your training substantially, since every metric computation
|
||||
requires a GPU sync to convert tensors to numpy.
|
||||
|
||||
|
||||
"""
|
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
This file provides functions and decorators for automated input and output
|
||||
conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to
|
||||
sync tensors between different processes in a DDP scenario, when needed.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numbers
|
||||
from typing import Union, Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data._utils.collate import np_str_obj_array_pattern
|
||||
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
|
||||
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all inputs of a function.
|
||||
Args:
|
||||
func_to_apply: the function to apply to the inputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
# actual function applying the give function to inputs
|
||||
def new_func(*args, **kwargs):
|
||||
args = func_to_apply(args, *dec_args, **dec_kwargs)
|
||||
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
|
||||
return func_to_decorate(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all outputs of a function.
|
||||
Args:
|
||||
func_to_apply: the function to apply to the outputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(function_to_decorate):
|
||||
# actual function applying the give function to outputs
|
||||
def new_func(*args, **kwargs):
|
||||
result = function_to_decorate(*args, **kwargs)
|
||||
return func_to_apply(result, *dec_args, **dec_kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def _convert_to_tensor(data: Any) -> Any:
|
||||
"""
|
||||
Maps all kind of collections and numbers to tensors.
|
||||
|
||||
Args:
|
||||
data: the data to convert to tensor
|
||||
|
||||
Returns:
|
||||
the converted data
|
||||
|
||||
"""
|
||||
if isinstance(data, numbers.Number):
|
||||
return torch.tensor([data])
|
||||
# is not array of object
|
||||
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
|
||||
return torch.from_numpy(data)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data
|
||||
|
||||
raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")
|
||||
|
||||
|
||||
def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
|
||||
"""Convert all tensors and numpy arrays to numpy arrays.
|
||||
Args:
|
||||
data: the tensor or array to convert to numpy
|
||||
|
||||
Returns:
|
||||
the resulting numpy array
|
||||
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.cpu().detach().numpy()
|
||||
elif isinstance(data, numbers.Number):
|
||||
return np.array([data])
|
||||
elif isinstance(data, np.ndarray):
|
||||
return data
|
||||
|
||||
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
|
||||
|
||||
|
||||
def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator handling the argument conversion for metrics working on numpy.
|
||||
All inputs of the decorated function will be converted to numpy and all
|
||||
outputs will be converted to tensors.
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
# applies collection conversion from tensor to numpy to all inputs
|
||||
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
|
||||
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
|
||||
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
return func_convert_in_out
|
||||
|
||||
|
||||
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator Handling the argument conversion for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
# converts all inputs to tensor if possible
|
||||
# we need to include tensors here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
|
||||
|
||||
def _sync_ddp_if_available(result: Union[torch.Tensor],
|
||||
group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to reduce the tensors from several ddp processes to one master process
|
||||
|
||||
Args:
|
||||
result: the value to sync and reduce (typically tensor or number)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum.
|
||||
|
||||
Returns:
|
||||
reduced value
|
||||
|
||||
"""
|
||||
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
if reduce_op is None:
|
||||
reduce_op = torch.distributed.ReduceOp.SUM
|
||||
|
||||
# sync all processes before reduction
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_reduce(result, op=reduce_op, group=group,
|
||||
async_op=False)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def numpy_metric(group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on numpy arrays.
|
||||
|
||||
It handles the argument conversion and DDP reduction for metrics working on numpy.
|
||||
All inputs of the decorated function will be converted to numpy and all
|
||||
outputs will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def tensor_metric(group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on tensors.
|
||||
|
||||
It handles the argument conversion and DDP reduction for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Returns:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
|
@ -0,0 +1,103 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
|
||||
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']
|
||||
|
||||
|
||||
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for metric implementation.
|
||||
|
||||
Should be used to implement metrics that
|
||||
1. Return multiple Outputs
|
||||
2. Handle their own DDP sync
|
||||
"""
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Args:
|
||||
name: the metric's name
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self._dtype = torch.get_default_dtype()
|
||||
self._device = torch.device('cpu')
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Implements the actual metric computation.
|
||||
|
||||
Returns:
|
||||
metric value
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TensorMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating directly on tensors.
|
||||
All inputs and outputs will be casted to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
def __init__(self, name: str,
|
||||
reduce_group: Optional[Any] = None,
|
||||
reduce_op: Optional[Any] = None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
name: the metric's name
|
||||
reduce_group: the process group for DDP reduces (only needed for DDP training).
|
||||
Defaults to all processes (world)
|
||||
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
|
||||
Defaults to sum.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._orig_call = tensor_metric(group=reduce_group,
|
||||
reduce_op=reduce_op)(super().__call__)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> torch.Tensor:
|
||||
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
|
||||
return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
|
||||
_to_device_dtype)
|
||||
|
||||
|
||||
class NumpyMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating on numpy arrays.
|
||||
All inputs will be casted to numpy if necessary and all outputs will
|
||||
be casted to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
def __init__(self, name: str,
|
||||
reduce_group: Optional[Any] = None,
|
||||
reduce_op: Optional[Any] = None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
name: the metric's name
|
||||
reduce_group: the process group for DDP reduces (only needed for DDP training).
|
||||
Defaults to all processes (world)
|
||||
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
|
||||
Defaults to sum.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._orig_call = numpy_metric(group=reduce_group,
|
||||
reduce_op=reduce_op)(super().__call__)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> torch.Tensor:
|
||||
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
|
||||
return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
|
||||
_to_device_dtype)
|
|
@ -0,0 +1,36 @@
|
|||
from collections import Mapping, Sequence
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
|
||||
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Recursively applies a function to all elements of a certain dtype.
|
||||
|
||||
Args:
|
||||
data: the collection to apply the function to
|
||||
dtype: the given function will be applied to all elements of this dtype
|
||||
function: the function to apply
|
||||
*args: positional arguments (will be forwarded to calls of ``function``)
|
||||
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
|
||||
|
||||
Returns:
|
||||
the resulting collection
|
||||
|
||||
"""
|
||||
elem_type = type(data)
|
||||
|
||||
# Breaking condition
|
||||
if isinstance(data, dtype):
|
||||
return function(data, *args, **kwargs)
|
||||
|
||||
# Recursively apply to collection items
|
||||
elif isinstance(data, Mapping):
|
||||
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
|
||||
for k, v in data.items()})
|
||||
elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
|
||||
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
|
||||
elif isinstance(data, Sequence) and not isinstance(data, str):
|
||||
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
|
||||
|
||||
# data is neither of dtype, nor a collection
|
||||
return data
|
|
@ -0,0 +1,214 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning.metrics.converters import (
|
||||
_apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy,
|
||||
_numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['args', 'kwargs'],
|
||||
[pytest.param([], {}),
|
||||
pytest.param([1., 2.], {}),
|
||||
pytest.param([], {'a': 1., 'b': 2.}),
|
||||
pytest.param([1., 2.], {'a': 1., 'b': 2.})])
|
||||
def test_apply_to_inputs(args, kwargs):
|
||||
def apply_fn(inputs, factor):
|
||||
if isinstance(inputs, (float, int)):
|
||||
return inputs * factor
|
||||
elif isinstance(inputs, dict):
|
||||
return {k: apply_fn(v, factor) for k, v in inputs.items()}
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
return [apply_fn(x, factor) for x in inputs]
|
||||
|
||||
@_apply_to_inputs(apply_fn, factor=2.)
|
||||
def test_fn(*func_args, **func_kwargs):
|
||||
return func_args, func_kwargs
|
||||
|
||||
result_args, result_kwargs = test_fn(*args, **kwargs)
|
||||
assert isinstance(result_args, (list, tuple))
|
||||
assert isinstance(result_kwargs, dict)
|
||||
assert len(result_args) == len(args)
|
||||
assert len(result_kwargs) == len(kwargs)
|
||||
assert all([k in result_kwargs for k in kwargs.keys()])
|
||||
for arg, result_arg in zip(args, result_args):
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
for key in kwargs.keys():
|
||||
arg = kwargs[key]
|
||||
result_arg = result_kwargs[key]
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
|
||||
def test_apply_to_outputs():
|
||||
def apply_fn(inputs, additional_str):
|
||||
return str(inputs) + additional_str
|
||||
|
||||
@_apply_to_outputs(apply_fn, additional_str='_str')
|
||||
def test_fn(*args, **kwargs):
|
||||
return 'dummy'
|
||||
|
||||
assert test_fn() == 'dummy_str'
|
||||
|
||||
|
||||
def test_convert_to_tensor():
|
||||
for test_item in [1., np.array([1.])]:
|
||||
result_tensor = _convert_to_tensor(test_item)
|
||||
assert isinstance(result_tensor, torch.Tensor)
|
||||
assert result_tensor.item() == 1.
|
||||
|
||||
|
||||
def test_convert_to_numpy():
|
||||
for test_item in [1., torch.tensor([1.])]:
|
||||
result = _convert_to_numpy(test_item)
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.item() == 1.
|
||||
|
||||
|
||||
def test_numpy_metric_conversion():
|
||||
@_numpy_metric_conversion
|
||||
def numpy_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, np.ndarray)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, np.ndarray)
|
||||
|
||||
return 5.
|
||||
|
||||
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5.
|
||||
|
||||
|
||||
def test_tensor_metric_conversion():
|
||||
@_tensor_metric_conversion
|
||||
def tensor_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, torch.Tensor)
|
||||
|
||||
return 5.
|
||||
|
||||
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5.
|
||||
|
||||
|
||||
def setup_ddp(rank, worldsize, ):
|
||||
import os
|
||||
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
|
||||
# initialize the process group
|
||||
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
|
||||
|
||||
|
||||
def ddp_test_fn(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
tensor = torch.tensor([1.], device='cuda:0')
|
||||
|
||||
reduced_tensor = _sync_ddp_if_available(tensor)
|
||||
|
||||
assert reduced_tensor.item() == dist.get_world_size(), \
|
||||
'Sync-Reduce does not work properly with DDP and Tensors'
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_sync_reduce_ddp():
|
||||
"""Make sure sync-reduce works with DDP"""
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize)
|
||||
|
||||
|
||||
def test_sync_reduce_simple():
|
||||
"""Make sure sync-reduce works without DDP"""
|
||||
tensor = torch.tensor([1.], device='cpu')
|
||||
|
||||
reduced_tensor = _sync_ddp_if_available(tensor)
|
||||
|
||||
assert torch.allclose(tensor, reduced_tensor), \
|
||||
'Sync-Reduce does not work properly without DDP and Tensors'
|
||||
|
||||
|
||||
def _test_tensor_metric(is_ddp: bool):
|
||||
@tensor_metric()
|
||||
def tensor_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, torch.Tensor)
|
||||
|
||||
return 5.
|
||||
|
||||
if is_ddp:
|
||||
factor = dist.get_world_size()
|
||||
else:
|
||||
factor = 1.
|
||||
|
||||
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5. * factor
|
||||
|
||||
|
||||
def _ddp_test_tensor_metric(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
_test_tensor_metric(True)
|
||||
|
||||
|
||||
def test_tensor_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
|
||||
|
||||
|
||||
def test_tensor_metric_simple():
|
||||
_test_tensor_metric(False)
|
||||
|
||||
|
||||
def _test_numpy_metric(is_ddp: bool):
|
||||
@numpy_metric()
|
||||
def numpy_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, np.ndarray)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, np.ndarray)
|
||||
|
||||
return 5.
|
||||
|
||||
if is_ddp:
|
||||
factor = dist.get_world_size()
|
||||
else:
|
||||
factor = 1.
|
||||
|
||||
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5. * factor
|
||||
|
||||
|
||||
def _ddp_test_numpy_metric(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
_test_numpy_metric(True)
|
||||
|
||||
|
||||
def test_numpy_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
|
||||
|
||||
|
||||
def test_numpy_metric_simple():
|
||||
_test_tensor_metric(False)
|
|
@ -0,0 +1,85 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
|
||||
|
||||
class DummyTensorMetric(TensorMetric):
|
||||
def __init__(self):
|
||||
super().__init__('dummy')
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, torch.Tensor)
|
||||
assert isinstance(input2, torch.Tensor)
|
||||
return 1.
|
||||
|
||||
|
||||
class DummyNumpyMetric(NumpyMetric):
|
||||
def __init__(self):
|
||||
super().__init__('dummy')
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, np.ndarray)
|
||||
assert isinstance(input2, np.ndarray)
|
||||
return 1.
|
||||
|
||||
|
||||
def _test_metric(metric: Metric):
|
||||
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
|
||||
|
||||
def change_and_check_device_dtype(device, dtype):
|
||||
metric.to(device=device, dtype=dtype)
|
||||
|
||||
metric_val = metric(input1, input2)
|
||||
assert isinstance(metric_val, torch.Tensor)
|
||||
|
||||
if device is not None:
|
||||
assert metric.device in [device, torch.device(device)]
|
||||
assert metric_val.device in [device, torch.device(device)]
|
||||
|
||||
if dtype is not None:
|
||||
assert metric.dtype == dtype
|
||||
assert metric_val.dtype == dtype
|
||||
|
||||
devices = [None, 'cpu']
|
||||
if torch.cuda.is_available():
|
||||
devices += ['cuda:0']
|
||||
|
||||
for device in devices:
|
||||
for dtype in [None, torch.float32, torch.float64]:
|
||||
change_and_check_device_dtype(device=device, dtype=dtype)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda(0)
|
||||
assert metric.device == torch.device('cuda', index=0)
|
||||
assert metric(input1, input2).device == torch.device('cuda', index=0)
|
||||
|
||||
metric.cpu()
|
||||
assert metric.device == torch.device('cpu')
|
||||
assert metric(input1, input2).device == torch.device('cpu')
|
||||
|
||||
metric.type(torch.int8)
|
||||
assert metric.dtype == torch.int8
|
||||
assert metric(input1, input2).dtype == torch.int8
|
||||
|
||||
metric.float()
|
||||
assert metric.dtype == torch.float32
|
||||
assert metric(input1, input2).dtype == torch.float32
|
||||
|
||||
metric.double()
|
||||
assert metric.dtype == torch.float64
|
||||
assert metric(input1, input2).dtype == torch.float64
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda()
|
||||
metric.half()
|
||||
assert metric.dtype == torch.float16
|
||||
assert metric(input1, input2).dtype == torch.float16
|
||||
|
||||
|
||||
def test_tensor_metric():
|
||||
_test_metric(DummyTensorMetric())
|
||||
|
||||
|
||||
def test_numpy_metric():
|
||||
_test_metric(DummyNumpyMetric())
|
|
@ -0,0 +1,66 @@
|
|||
import numbers
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
|
||||
def test_recursive_application_to_collection():
|
||||
ntc = namedtuple('Foo', ['bar'])
|
||||
|
||||
to_reduce = {
|
||||
'a': torch.tensor([1.]), # Tensor
|
||||
'b': [torch.tensor([2.])], # list
|
||||
'c': (torch.tensor([100.]),), # tuple
|
||||
'd': ntc(bar=5.), # named tuple
|
||||
'e': np.array([10.]), # numpy array
|
||||
'f': 'this_is_a_dummy_str', # string
|
||||
'g': 12. # number
|
||||
}
|
||||
|
||||
expected_result = {
|
||||
'a': torch.tensor([2.]),
|
||||
'b': [torch.tensor([4.])],
|
||||
'c': (torch.tensor([200.]),),
|
||||
'd': ntc(bar=torch.tensor([10.])),
|
||||
'e': np.array([20.]),
|
||||
'f': 'this_is_a_dummy_str',
|
||||
'g': 24.
|
||||
}
|
||||
|
||||
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray),
|
||||
lambda x: x * 2)
|
||||
|
||||
assert isinstance(reduced, dict), ' Type Consistency of dict not preserved'
|
||||
assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved'
|
||||
assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \
|
||||
'At least one type was not correctly preserved'
|
||||
|
||||
assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor'
|
||||
assert torch.allclose(expected_result['a'], reduced['a']), \
|
||||
'Reduction of a tensor does not yield the expected value'
|
||||
|
||||
assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list'
|
||||
assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \
|
||||
'At least one value of list reduction did not come out as expected'
|
||||
|
||||
assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple'
|
||||
assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \
|
||||
'At least one value of tuple reduction did not come out as expected'
|
||||
|
||||
assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given'
|
||||
assert isinstance(reduced['d'].bar, numbers.Number), \
|
||||
'Failure in type promotion while reducing fields of named tuples'
|
||||
assert reduced['d'].bar == expected_result['d'].bar
|
||||
|
||||
assert isinstance(reduced['e'], np.ndarray), 'Type Promotion in reduction of numpy arrays failed'
|
||||
assert reduced['e'] == expected_result['e'], \
|
||||
'Reduction of numpy array did not yield the expected result'
|
||||
|
||||
assert isinstance(reduced['f'], str), 'A string should not be reduced'
|
||||
assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'
|
||||
|
||||
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor'
|
||||
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
|
Loading…
Reference in New Issue