From 827a557269adedbe2594a820b13c9be100808b51 Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Fri, 16 Oct 2020 14:36:03 -0400
Subject: [PATCH] Add persistent flag to Metric.add_state (#4195)

* add persistent flag to add_state in metrics

* wrap register_buffer with try catch

* pep8

* use loose version

* test

* pep8
---
 pytorch_lightning/metrics/metric.py | 12 ++++++++++--
 tests/metrics/test_metric.py        | 14 ++++++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 8911341466..a6da405730 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -17,6 +17,7 @@ from typing import Any, Callable, Optional, Union
 from collections.abc import Mapping, Sequence
 from collections import namedtuple
 from copy import deepcopy
+from distutils.version import LooseVersion
 import os
 
 import torch
@@ -78,7 +79,9 @@ class Metric(nn.Module, ABC):
         self._reductions = {}
         self._defaults = {}
 
-    def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
+    def add_state(
+        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+    ):
         """
         Adds metric state variable. Only used by subclasses.
 
@@ -90,6 +93,7 @@
             If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
             and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
             function in this parameter.
+            persistent (Optional): whether the state will be saved as part of the module's ``state_dict``.
 
         Note:
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -130,7 +134,11 @@ class Metric(nn.Module, ABC):
         )
 
         if isinstance(default, torch.Tensor):
-            self.register_buffer(name, default)
+            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+                # persistent keyword is only supported in torch >= 1.6.0
+                self.register_buffer(name, default, persistent=persistent)
+            else:
+                self.register_buffer(name, default)
         else:
             setattr(self, name, default)
 
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f320e26445..ccb9b4ad0a 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -1,5 +1,6 @@
 import pickle
 
+from distutils.version import LooseVersion
 import cloudpickle
 import numpy as np
 import pytest
@@ -59,6 +60,19 @@ def test_add_state():
     assert a._reductions["e"](torch.tensor([1, 1])) == -1
 
 
+def test_add_state_persistent():
+    a = Dummy()
+
+    a.add_state("a", torch.tensor(0), "sum", persistent=True)
+    assert "a" in a.state_dict()
+
+    a.add_state("b", torch.tensor(0), "sum", persistent=False)
+
+    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+        assert "b" not in a.state_dict()
+
+
+
 def test_reset():
     class A(Dummy):
         pass
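
For reference, here is a minimal sketch of how the new flag is used from a metric subclass. This is an illustration, not part of the patch: the CountedSum class is hypothetical, and the `from pytorch_lightning.metrics import Metric` import path is assumed from the Lightning 1.0.x API. Note that on torch < 1.6.0 the fallback branch above silently ignores `persistent=False`, so the state still appears in ``state_dict`` there, which is why the test only asserts the exclusion on torch >= 1.6.0.

import torch

from pytorch_lightning.metrics import Metric  # assumed import path (Lightning 1.0.x)


class CountedSum(Metric):  # hypothetical example subclass
    def __init__(self):
        super().__init__()
        # persistent=True is the default: "total" is saved in state_dict()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # persistent=False keeps "count" out of state_dict() on torch >= 1.6.0;
        # on older torch the flag is silently ignored by the fallback branch
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum", persistent=False)

    def update(self, x: torch.Tensor):
        self.total += x.sum()
        self.count += x.numel()

    def compute(self):
        return self.total / self.count


metric = CountedSum()
metric.update(torch.tensor([1.0, 2.0, 3.0]))
assert metric.compute() == 2.0          # 6.0 / 3
assert "total" in metric.state_dict()   # persistent state is checkpointed
# on torch >= 1.6.0, "count" is absent from metric.state_dict()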