Fix metric state reset (#5273)
* Fix metric state reset
* Fix test
* Improve formatting
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
(cherry picked from commit 4913cbb987
)
This commit is contained in:
parent
f790d30d7e
commit
4292fe0532
|
@ -94,7 +94,8 @@ class Metric(nn.Module, ABC):
|
|||
reset to this value when ``self.reset()`` is called.
|
||||
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
|
||||
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
|
||||
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
|
||||
function in this parameter.
|
||||
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
|
||||
Default is ``False``.
|
||||
|
@ -244,7 +245,7 @@ class Metric(nn.Module, ABC):
|
|||
"""
|
||||
for attr, default in self._defaults.items():
|
||||
current_val = getattr(self, attr)
|
||||
if isinstance(current_val, torch.Tensor):
|
||||
if isinstance(default, torch.Tensor):
|
||||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||
else:
|
||||
setattr(self, attr, deepcopy(default))
|
||||
|
|
|
@ -26,6 +26,20 @@ class Dummy(Metric):
|
|||
pass
|
||||
|
||||
|
||||
class DummyList(Metric):
|
||||
name = "DummyList"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", list(), dist_reduce_fx=None)
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def compute(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_inherit():
|
||||
Dummy()
|
||||
|
||||
|
@ -77,12 +91,21 @@ def test_reset():
|
|||
class A(Dummy):
|
||||
pass
|
||||
|
||||
class B(DummyList):
|
||||
pass
|
||||
|
||||
a = A()
|
||||
assert a.x == 0
|
||||
a.x = torch.tensor(5)
|
||||
a.reset()
|
||||
assert a.x == 0
|
||||
|
||||
b = B()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
b.x = torch.tensor(5)
|
||||
b.reset()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
|
||||
|
||||
def test_update():
|
||||
class A(Dummy):
|
||||
|
|
Loading…
Reference in New Issue