From 4292fe0532b9b5f68af8a48d8b844585a3b59f86 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 29 Dec 2020 22:09:10 +0100 Subject: [PATCH] Fix metric state reset (#5273) * Fix metric state reset * Fix test * Improve formatting Co-authored-by: Ananya Harsh Jha (cherry picked from commit 4913cbb987a0516f8b33c016134b19c0588d107a) --- pytorch_lightning/metrics/metric.py | 5 +++-- tests/metrics/test_metric.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 0f61b94c55..a21242c3bd 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -94,7 +94,8 @@ class Metric(nn.Module, ABC): reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom + and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction + only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter. persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. Default is ``False``. @@ -244,7 +245,7 @@ class Metric(nn.Module, ABC): """ for attr, default in self._defaults.items(): current_val = getattr(self, attr) - if isinstance(current_val, torch.Tensor): + if isinstance(default, torch.Tensor): setattr(self, attr, deepcopy(default).to(current_val.device)) else: setattr(self, attr, deepcopy(default)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index d1c8b8c441..33948204cb 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -26,6 +26,20 @@ class Dummy(Metric): pass +class DummyList(Metric): + name = "DummyList" + + def __init__(self): + super().__init__() + self.add_state("x", list(), dist_reduce_fx=None) + + def update(self): + pass + + def compute(self): + pass + + def test_inherit(): Dummy() @@ -77,12 +91,21 @@ def test_reset(): class A(Dummy): pass + class B(DummyList): + pass + a = A() assert a.x == 0 a.x = torch.tensor(5) a.reset() assert a.x == 0 + b = B() + assert isinstance(b.x, list) and len(b.x) == 0 + b.x = torch.tensor(5) + b.reset() + assert isinstance(b.x, list) and len(b.x) == 0 + def test_update(): class A(Dummy):