147 lines
5.3 KiB
Python
147 lines
5.3 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import torch.distributed
|
|
|
|
from pytorch_lightning.metrics.converters import (
|
|
tensor_metric, numpy_metric, tensor_collection_metric)
|
|
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
|
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
|
|
|
|
|
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
    """
    Abstract base class that every metric implementation derives from.

    Intended for metrics that

        1. return multiple outputs, and
        2. take care of their own DDP synchronisation.

    Subclasses must implement :meth:`forward`.
    """

    def __init__(self, name: str):
        """
        Args:
            name: identifier of the metric, stored as ``self.name``
        """
        super().__init__()
        self.name = name
        # Initial dtype is the global torch default and the initial device is the CPU.
        # NOTE(review): these private attributes presumably back the ``dtype`` /
        # ``device`` properties of DeviceDtypeModuleMixin — confirm there.
        self._dtype = torch.get_default_dtype()
        self._device = torch.device('cpu')

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Compute the metric; every concrete subclass has to override this.

        Returns:
            the computed metric value

        Raises:
            NotImplementedError: if reached without being overridden
        """
        raise NotImplementedError
|
|
|
|
|
|
class TensorMetric(Metric):
    """
    Base class for metrics that operate directly on tensors.

    Inputs and outputs are cast to tensors where necessary; DDP
    synchronisation and the input/output conversion are already handled.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: process group used for DDP reductions (only needed
                for DDP training). Defaults to all processes (world).
            reduce_op: reduction operation applied within DDP (only needed
                for DDP training). Defaults to sum.
        """
        super().__init__(name)
        # Wrap the parent class's __call__ with the tensor-conversion / DDP-sync
        # decorator once, at construction time.
        decorate = tensor_metric(group=reduce_group, reduce_op=reduce_op)
        self._orig_call = decorate(super().__call__)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the decorated metric call, then move every tensor in the result
        to this metric's ``device`` and cast it to its ``dtype``.
        """
        output = self._orig_call(*args, **kwargs)

        def _move(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.to(device=self.device, dtype=self.dtype, non_blocking=True)

        return apply_to_collection(output, torch.Tensor, _move)
|
|
|
|
|
|
class TensorCollectionMetric(Metric):
    """
    Base class for metric implementation operating directly on tensors.

    All inputs will be cast to tensors if necessary. Outputs are not converted
    to tensors at the top level. Already handles DDP sync and input conversions.

    This class differs from :class:`TensorMetric`, as it assumes all outputs to
    be collections of tensors and does not explicitly convert them. This is
    necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
    cannot be converted to tensors at the highest level.
    All numpy arrays and numbers occurring in these outputs will still be converted.
    Every tensor found inside the returned collection is moved to this metric's
    ``device`` and cast to its ``dtype``.

    Use this class as a base class whenever you want to ensure inputs are
    tensors and outputs cannot be converted to tensors automatically.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        # Wrap the parent class's __call__ with the collection-aware
        # conversion / DDP-sync decorator once, at construction time.
        self._orig_call = tensor_collection_metric(group=reduce_group,
                                                   reduce_op=reduce_op)(super().__call__)

    def __call__(self, *args, **kwargs) -> Any:
        """
        Run the decorated metric call and move its tensors to this metric's device/dtype.

        Returns:
            the metric output collection (not necessarily a single tensor), with
            every contained ``torch.Tensor`` moved to ``self.device`` and cast to
            ``self.dtype``; non-tensor entries are passed through by
            ``apply_to_collection``.
        """
        def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
            return x.to(device=self.device, dtype=self.dtype, non_blocking=True)

        return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
                                   _to_device_dtype)
|
|
|
|
|
|
class NumpyMetric(Metric):
    """
    Base class for metrics that operate on numpy arrays.

    Inputs are cast to numpy where necessary and outputs are cast back to
    tensors where necessary; DDP synchronisation and both conversions are
    already handled.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: process group used for DDP reductions (only needed
                for DDP training). Defaults to all processes (world).
            reduce_op: reduction operation applied within DDP (only needed
                for DDP training). Defaults to sum.
        """
        super().__init__(name)
        # Bind the numpy-conversion / DDP-sync decorator around the parent
        # class's __call__ a single time.
        numpy_wrapper = numpy_metric(group=reduce_group, reduce_op=reduce_op)
        self._orig_call = numpy_wrapper(super().__call__)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """
        Evaluate the decorated metric and relocate each tensor in its result
        to this metric's ``device`` with this metric's ``dtype``.
        """
        result = self._orig_call(*args, **kwargs)

        def _relocate(t: torch.Tensor) -> torch.Tensor:
            return t.to(device=self.device, dtype=self.dtype, non_blocking=True)

        return apply_to_collection(result, torch.Tensor, _relocate)
|