lightning/tests/metrics/test_metric.py

181 lines
3.8 KiB
Python
Raw Normal View History

import pickle
from collections import OrderedDict
from distutils.version import LooseVersion
import cloudpickle
import numpy as np
import pytest
import torch
from pytorch_lightning.metrics.metric import Metric
torch.manual_seed(42)
class Dummy(Metric):
name = "Dummy"
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
def update(self):
pass
def compute(self):
pass
def test_inherit():
a = Dummy()
def test_add_state():
a = Dummy()
a.add_state("a", torch.tensor(0), "sum")
assert a._reductions["a"](torch.tensor([1, 1])) == 2
a.add_state("b", torch.tensor(0), "mean")
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("c", torch.tensor(0), "cat")
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
with pytest.raises(ValueError):
a.add_state("d1", torch.tensor(0), 'xyz')
with pytest.raises(ValueError):
a.add_state("d2", torch.tensor(0), 42)
with pytest.raises(ValueError):
a.add_state("d3", [torch.tensor(0)], 'sum')
with pytest.raises(ValueError):
a.add_state("d4", 42, 'sum')
def custom_fx(x):
return -1
a.add_state("e", torch.tensor(0), custom_fx)
assert a._reductions["e"](torch.tensor([1, 1])) == -1
def test_add_state_persistent():
a = Dummy()
a.add_state("a", torch.tensor(0), "sum", persistent=True)
assert "a" in a.state_dict()
a.add_state("b", torch.tensor(0), "sum", persistent=False)
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
assert "b" not in a.state_dict()
def test_reset():
class A(Dummy):
pass
a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0
def test_update():
class A(Dummy):
def update(self, x):
self.x += x
a = A()
assert a.x == 0
assert a._computed is None
a.update(1)
assert a._computed is None
assert a.x == 1
a.update(2)
assert a.x == 3
assert a._computed is None
def test_compute():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert 0 == a.compute()
assert 0 == a.x
a.update(1)
assert a._computed is None
assert a.compute() == 1
assert a._computed == 1
a.update(2)
assert a._computed is None
assert a.compute() == 2
assert a._computed == 2
# called without update, should return cached value
a._computed = 5
assert a.compute() == 5
def test_forward():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert a(5) == 5
assert a._forward_cache == 5
assert a(8) == 8
assert a._forward_cache == 8
assert a.compute() == 13
class ToPickle(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
def test_pickle(tmpdir):
# doesn't tests for DDP
a = ToPickle()
a.update(1)
metric_pickled = pickle.dumps(a)
metric_loaded = pickle.loads(metric_pickled)
assert metric_loaded.compute() == 1
metric_loaded.update(5)
assert metric_loaded.compute() == 5
metric_pickled = cloudpickle.dumps(a)
metric_loaded = cloudpickle.loads(metric_pickled)
assert metric_loaded.compute() == 1
def test_state_dict(tmpdir):
""" test that metric states can be removed and added to state dict """
metric = Dummy()
assert metric.state_dict() == OrderedDict()
metric.persistent(True)
assert metric.state_dict() == OrderedDict(x=0)
metric.persistent(False)
assert metric.state_dict() == OrderedDict()