170 lines
3.4 KiB
Python
170 lines
3.4 KiB
Python
import pickle
|
|
|
|
from distutils.version import LooseVersion
|
|
import cloudpickle
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from pytorch_lightning.metrics.metric import Metric
|
|
|
|
# Seed torch's global RNG once, at import time, so runs of this module are repeatable.
torch.manual_seed(42)
|
|
|
|
|
|
class Dummy(Metric):
    """Bare-bones ``Metric`` subclass that the tests in this module build on.

    It registers a single state ``x`` (default ``torch.tensor(0)``, with
    ``dist_reduce_fx=None``) and provides ``update``/``compute`` bodies that
    do nothing, so individual tests can override only what they need.
    """

    name = "Dummy"

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0), dist_reduce_fx=None)

    def update(self):
        """Accept no inputs and leave the state untouched."""

    def compute(self):
        """Produce no result; test subclasses override this."""
|
|
|
|
|
|
def test_inherit():
    """Smoke test: a concrete ``Metric`` subclass can be instantiated without error."""
    Dummy()
|
|
|
|
|
|
def test_add_state():
    """Exercise ``add_state`` with every kind of ``dist_reduce_fx`` argument.

    Covers the string shortcuts ``"sum"``, ``"mean"`` and ``"cat"`` (by calling
    the reduction stored in ``_reductions``), a set of argument combinations
    that must raise ``ValueError``, and a user-supplied callable.
    """
    metric = Dummy()

    # "sum": the stored reduction adds the elements up.
    metric.add_state("a", torch.tensor(0), "sum")
    summed = metric._reductions["a"](torch.tensor([1, 1]))
    assert summed == 2

    # "mean": the stored reduction averages the elements.
    metric.add_state("b", torch.tensor(0), "mean")
    averaged = metric._reductions["b"](torch.tensor([1.0, 2.0]))
    assert np.allclose(averaged.numpy(), 1.5)

    # "cat": the stored reduction joins a list of tensors into one.
    metric.add_state("c", torch.tensor(0), "cat")
    joined = metric._reductions["c"]([torch.tensor([1]), torch.tensor([1])])
    assert joined.shape == (2,)

    # Argument combinations the test expects to be rejected: an unknown
    # reduction string, an int as reduction, a non-empty list default, and
    # a plain int default.
    rejected_calls = (
        ("d1", torch.tensor(0), 'xyz'),
        ("d2", torch.tensor(0), 42),
        ("d3", [torch.tensor(0)], 'sum'),
        ("d4", 42, 'sum'),
    )
    for state_name, default, reduce_fx in rejected_calls:
        with pytest.raises(ValueError):
            metric.add_state(state_name, default, reduce_fx)

    def constant_reduction(x):
        return -1

    # A callable is stored as given, so calling it yields its own result.
    metric.add_state("e", torch.tensor(0), constant_reduction)
    assert metric._reductions["e"](torch.tensor([1, 1])) == -1
|
|
|
|
|
|
def test_add_state_persistent():
    """Check how the ``persistent`` flag of ``add_state`` affects ``state_dict()``.

    A state added with ``persistent=True`` must appear in the state dict. A
    state added with ``persistent=False`` must be absent, but that is only
    asserted when the installed torch is at least 1.6.0.
    """
    metric = Dummy()

    metric.add_state("a", torch.tensor(0), "sum", persistent=True)
    assert "a" in metric.state_dict()

    metric.add_state("b", torch.tensor(0), "sum", persistent=False)

    # NOTE(review): presumably older torch releases cannot exclude a state
    # from the state dict, hence the version gate - confirm against Metric.
    torch_at_least_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
    if torch_at_least_1_6:
        assert "b" not in metric.state_dict()
|
|
|
|
|
|
def test_reset():
    """Check that ``reset()`` puts the state ``x`` back to its default of 0."""

    class A(Dummy):
        """Plain subclass; inherits everything from ``Dummy``."""

    metric = A()
    assert metric.x == 0

    # Overwrite the state by hand, then make sure reset() undoes it.
    metric.x = torch.tensor(5)
    metric.reset()
    assert metric.x == 0
|
|
|
|
|
|
def test_update():
    """Check that ``update`` accumulates into ``x`` and never sets ``_computed``."""

    class A(Dummy):
        def update(self, x):
            self.x += x

    metric = A()
    assert metric.x == 0
    assert metric._computed is None

    # (value passed to update, running total expected in ``x`` afterwards)
    steps = ((1, 1), (2, 3))
    for increment, expected_total in steps:
        metric.update(increment)
        assert metric._computed is None
        assert metric.x == expected_total
|
|
|
|
|
|
def test_compute():
    """Check the interplay between ``update``, ``compute`` and the ``_computed`` cache.

    The assertions expect that ``update`` clears ``_computed``, that
    ``compute`` stores its result there, and that each ``compute`` reports
    only the value passed to the preceding ``update`` (1, then 2).
    """

    class A(Dummy):
        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    metric = A()
    assert metric.compute() == 0
    assert metric.x == 0

    # NOTE(review): the second compute() is expected to give 2 rather than 3,
    # so compute() presumably resets the state - confirm against Metric.
    for value in (1, 2):
        metric.update(value)
        assert metric._computed is None
        assert metric.compute() == value
        assert metric._computed == value

    # With no update in between, compute() must hand back the cached value.
    metric._computed = 5
    assert metric.compute() == 5
|
|
|
|
|
|
def test_forward():
    """Check that calling the metric reports per-call values while state accumulates.

    Each call is expected to return the value for that call alone and to keep
    it in ``_forward_cache``; a final ``compute()`` is expected to return the
    total over all calls.
    """

    class A(Dummy):
        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    metric = A()

    for batch_value in (5, 8):
        assert metric(batch_value) == batch_value
        assert metric._forward_cache == batch_value

    # 5 + 8 accumulated across both forward calls.
    assert metric.compute() == 13
|
|
|
|
|
|
class ToPickle(Dummy):
    """Summing metric defined at module level so ``test_pickle`` can serialise it."""

    def update(self, x):
        """Add ``x`` to the state ``x`` in place."""
        self.x += x

    def compute(self):
        """Return the current value of the state ``x``."""
        return self.x
|
|
|
|
|
|
def test_pickle(tmpdir):
    """Check that a metric survives a ``pickle`` and a ``cloudpickle`` round trip.

    The ``tmpdir`` fixture is requested but not used in the body.
    """
    # This does not test pickling under DDP.
    original = ToPickle()
    original.update(1)

    # Standard-library pickle: the restored copy carries the state over ...
    restored = pickle.loads(pickle.dumps(original))
    assert restored.compute() == 1

    # ... and is still usable as a metric afterwards.
    restored.update(5)
    assert restored.compute() == 5

    # cloudpickle round trip of the same, untouched original.
    restored = cloudpickle.loads(cloudpickle.dumps(original))
    assert restored.compute() == 1
|