lightning/tests/metrics/test_metric.py

157 lines
3.0 KiB
Python
Raw Normal View History

import pytest
import torch
from pytorch_lightning.metrics.metric import Metric
import os
import numpy as np
import pickle
import cloudpickle
torch.manual_seed(42)
class Dummy(Metric):
name = "Dummy"
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
def update(self):
pass
def compute(self):
pass
def test_inherit():
a = Dummy()
def test_add_state():
a = Dummy()
a.add_state("a", torch.tensor(0), "sum")
assert a._reductions["a"](torch.tensor([1, 1])) == 2
a.add_state("b", torch.tensor(0), "mean")
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("c", torch.tensor(0), "cat")
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
with pytest.raises(ValueError):
a.add_state("d1", torch.tensor(0), 'xyz')
with pytest.raises(ValueError):
a.add_state("d2", torch.tensor(0), 42)
with pytest.raises(ValueError):
a.add_state("d3", [torch.tensor(0)], 'sum')
with pytest.raises(ValueError):
a.add_state("d4", 42, 'sum')
def custom_fx(x):
return -1
a.add_state("e", torch.tensor(0), custom_fx)
assert a._reductions["e"](torch.tensor([1, 1])) == -1
def test_reset():
class A(Dummy):
pass
a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0
def test_update():
class A(Dummy):
def update(self, x):
self.x += x
a = A()
assert a.x == 0
assert a._computed is None
a.update(1)
assert a._computed is None
assert a.x == 1
a.update(2)
assert a.x == 3
assert a._computed is None
def test_compute():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert 0 == a.compute()
assert 0 == a.x
a.update(1)
assert a._computed is None
assert a.compute() == 1
assert a._computed == 1
a.update(2)
assert a._computed is None
assert a.compute() == 2
assert a._computed == 2
# called without update, should return cached value
a._computed = 5
assert a.compute() == 5
def test_forward():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert a(5) == 5
assert a._forward_cache == 5
assert a(8) == 8
assert a._forward_cache == 8
assert a.compute() == 13
class ToPickle(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
def test_pickle(tmpdir):
# doesn't tests for DDP
a = ToPickle()
a.update(1)
metric_pickled = pickle.dumps(a)
metric_loaded = pickle.loads(metric_pickled)
assert metric_loaded.compute() == 1
metric_loaded.update(5)
assert metric_loaded.compute() == 5
metric_pickled = cloudpickle.dumps(a)
metric_loaded = cloudpickle.loads(metric_pickled)
assert metric_loaded.compute() == 1