93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
from typing import Callable, Union
|
|
|
|
import torch
|
|
|
|
from pytorch_lightning.metrics.metric import Metric
|
|
|
|
|
|
class CompositionalMetric(Metric):
    """Composition of two metrics with a specific operator
    which will be executed upon metric's compute

    Each operand may be a ``Metric``, a ``torch.Tensor`` or a plain Python
    number; ``metric_b`` may additionally be ``None`` for operators that take
    a single argument. Only operands that are ``Metric`` instances are
    updated, reset or have their persistence changed by this class; any other
    operand is handed to the operator unchanged.
    """

    def __init__(
        self,
        operator: Callable,
        metric_a: Union[Metric, int, float, torch.Tensor],
        metric_b: Union[Metric, int, float, torch.Tensor, None] = None,
    ):
        """
        Args:
            operator: the operator taking in one (if metric_b is None)
                or two arguments. Will be applied to outputs of metric_a.compute()
                and (optionally if metric_b is not None) metric_b.compute()
            metric_a: first metric whose compute() result is the first argument of operator
            metric_b: second metric whose compute() result is the second argument of operator.
                For operators taking in only one input, this should be None
                (which is also the default)
        """
        super().__init__()

        self.op = operator

        # Tensor operands are stored through ``register_buffer`` instead of
        # plain attribute assignment; everything else (metrics, numbers,
        # ``None``) is stored as a regular attribute.
        if isinstance(metric_a, torch.Tensor):
            self.register_buffer("metric_a", metric_a)
        else:
            self.metric_a = metric_a

        if isinstance(metric_b, torch.Tensor):
            self.register_buffer("metric_b", metric_b)
        else:
            self.metric_b = metric_b

    def _sync_dist(self, dist_sync_fn=None):
        """Do nothing; ``dist_sync_fn`` is accepted but ignored."""
        # No syncing required here. syncing will be done in metric_a and metric_b
        pass

    def update(self, *args, **kwargs):
        """Forward an update to every operand that is a ``Metric``.

        Positional arguments are passed through unchanged. Keyword arguments
        are first passed through the receiving metric's ``_filter_kwargs``.
        Operands that are not ``Metric`` instances are left untouched.
        """
        if isinstance(self.metric_a, Metric):
            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

        if isinstance(self.metric_b, Metric):
            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

    def compute(self):
        """Apply the operator to the current value of both operands.

        A ``Metric`` operand contributes the result of its ``compute()``;
        any other operand contributes itself.

        Returns:
            ``operator(val_a)`` when the second value is ``None``, otherwise
            ``operator(val_a, val_b)``.
        """
        # also some parsing for kwargs?
        if isinstance(self.metric_a, Metric):
            val_a = self.metric_a.compute()
        else:
            val_a = self.metric_a

        if isinstance(self.metric_b, Metric):
            val_b = self.metric_b.compute()
        else:
            val_b = self.metric_b

        # A missing second value means the operator is called with one argument.
        if val_b is None:
            return self.op(val_a)

        return self.op(val_a, val_b)

    def reset(self):
        """Reset every operand that is a ``Metric``."""
        if isinstance(self.metric_a, Metric):
            self.metric_a.reset()

        if isinstance(self.metric_b, Metric):
            self.metric_b.reset()

    def persistent(self, mode: bool = False):
        """Forward ``persistent(mode=mode)`` to every operand that is a ``Metric``."""
        if isinstance(self.metric_a, Metric):
            self.metric_a.persistent(mode=mode)
        if isinstance(self.metric_b, Metric):
            self.metric_b.persistent(mode=mode)

    def __repr__(self):
        """Return a multi-line description showing the operator and both operands."""
        # Not every callable has ``__name__`` (for example ``functools.partial``
        # objects or instances that define ``__call__``); fall back to the
        # operator's own repr so that printing the metric never raises.
        op_name = getattr(self.op, "__name__", repr(self.op))
        repr_str = (
            self.__class__.__name__
            + f"(\n {op_name}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
        )

        return repr_str
|