Add persistent flag to Metric.add_state (#4195)
* add persistant flag to add_state in metrics * wrap register_buffer with try catch * pep8 * use loose version * test * pep8
This commit is contained in:
parent
3fe479f348
commit
827a557269
|
@ -17,6 +17,7 @@ from typing import Any, Callable, Optional, Union
|
|||
from collections.abc import Mapping, Sequence
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
@ -78,7 +79,9 @@ class Metric(nn.Module, ABC):
|
|||
self._reductions = {}
|
||||
self._defaults = {}
|
||||
|
||||
def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
|
||||
def add_state(
|
||||
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
|
||||
):
|
||||
"""
|
||||
Adds metric state variable. Only used by subclasses.
|
||||
|
||||
|
@ -90,6 +93,7 @@ class Metric(nn.Module, ABC):
|
|||
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
|
||||
function in this parameter.
|
||||
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
|
||||
|
||||
Note:
|
||||
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
|
||||
|
@ -130,7 +134,11 @@ class Metric(nn.Module, ABC):
|
|||
)
|
||||
|
||||
if isinstance(default, torch.Tensor):
|
||||
self.register_buffer(name, default)
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
# persistent keyword is only supported in torch >= 1.6.0
|
||||
self.register_buffer(name, default, persistent=persistent)
|
||||
else:
|
||||
self.register_buffer(name, default)
|
||||
else:
|
||||
setattr(self, name, default)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pickle
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import cloudpickle
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -59,6 +60,19 @@ def test_add_state():
|
|||
assert a._reductions["e"](torch.tensor([1, 1])) == -1
|
||||
|
||||
|
||||
def test_add_state_persistent():
|
||||
a = Dummy()
|
||||
|
||||
a.add_state("a", torch.tensor(0), "sum", persistent=True)
|
||||
assert "a" in a.state_dict()
|
||||
|
||||
a.add_state("b", torch.tensor(0), "sum", persistent=False)
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
assert "b" not in a.state_dict()
|
||||
|
||||
|
||||
|
||||
def test_reset():
|
||||
class A(Dummy):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue