lightning/pytorch_lightning/metrics/metric.py

147 lines
5.3 KiB
Python

from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
import torch.distributed
from pytorch_lightning.metrics.converters import (
tensor_metric, numpy_metric, tensor_collection_metric)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
    """
    Abstract root class for metric implementations.

    Derive from this class directly when the metric

    1. returns multiple outputs
    2. handles its own DDP sync
    """

    def __init__(self, name: str):
        """
        Args:
            name: the metric's name
        """
        super().__init__()
        # A fresh metric starts on the CPU with torch's current default dtype.
        self._device = torch.device('cpu')
        self._dtype = torch.get_default_dtype()
        self.name = name

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Perform the actual metric computation; subclasses must override this.

        Returns:
            metric value
        """
        raise NotImplementedError
class TensorMetric(Metric):
    """
    Metric base class whose computation operates directly on tensors.

    Inputs and outputs are cast to tensors where necessary. DDP sync and the
    input/output conversions are already handled.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        # Build the tensor_metric decorator once, then wrap the inherited
        # module call with it so that __call__ below can delegate to it.
        decorate = tensor_metric(group=reduce_group, reduce_op=reduce_op)
        self._orig_call = decorate(super().__call__)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the wrapped call and move its tensor output(s) to this metric's
        device and dtype.
        """
        raw_result = self._orig_call(*args, **kwargs)
        # device/dtype are read lazily, i.e. only when a tensor is converted.
        return apply_to_collection(
            raw_result,
            torch.Tensor,
            lambda t: t.to(device=self.device, dtype=self.dtype, non_blocking=True),
        )
class TensorCollectionMetric(Metric):
    """
    Base class for metric implementation operating directly on tensors.
    All inputs will be cast to tensors if necessary. Outputs won't be cast.
    Already handles DDP sync and input conversions.

    This class differs from :class:`TensorMetric`, as it assumes all outputs to
    be collections of tensors and does not explicitly convert them. This is
    necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
    cannot be converted to tensors at the highest level.
    All numpy arrays and numbers occurring in these outputs will still be converted.

    Use this class as a baseclass, whenever you want to ensure inputs are
    tensors and outputs cannot be converted to tensors automatically
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        # Wrap the inherited module call with the tensor_collection_metric
        # decorator; __call__ below delegates to this wrapped callable.
        self._orig_call = tensor_collection_metric(group=reduce_group,
                                                   reduce_op=reduce_op)(super().__call__)

    def __call__(self, *args, **kwargs) -> Any:
        """
        Run the wrapped call and move every tensor found in its output to this
        metric's device and dtype.

        Returns:
            the output of the wrapped call (per the class contract a collection
            of tensors rather than a single tensor, hence the ``Any`` annotation)
            after ``apply_to_collection`` has applied the device/dtype
            conversion to its ``torch.Tensor`` elements
        """
        def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
            return x.to(device=self.device, dtype=self.dtype, non_blocking=True)

        return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
                                   _to_device_dtype)
class NumpyMetric(Metric):
    """
    Metric base class whose computation operates on numpy arrays.

    Inputs are cast to numpy where necessary and outputs are cast to tensors
    where necessary. DDP sync and the input/output conversions are already
    handled.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        # Build the numpy_metric decorator once, then wrap the inherited
        # module call with it so that __call__ below can delegate to it.
        decorate = numpy_metric(group=reduce_group, reduce_op=reduce_op)
        self._orig_call = decorate(super().__call__)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the wrapped call and move its tensor output(s) to this metric's
        device and dtype.
        """
        raw_result = self._orig_call(*args, **kwargs)
        # device/dtype are read lazily, i.e. only when a tensor is converted.
        return apply_to_collection(
            raw_result,
            torch.Tensor,
            lambda t: t.to(device=self.device, dtype=self.dtype, non_blocking=True),
        )