# Source: lightning/tests/metrics/test_metric.py (338 lines, 8.5 KiB, Python)

import pickle
from collections import OrderedDict
from distutils.version import LooseVersion
import cloudpickle
import numpy as np
import pytest
import torch
from pytorch_lightning.metrics.metric import Metric, MetricCollection
# Seed torch's global RNG once at import time so the tests below run deterministically.
torch.manual_seed(42)
class Dummy(Metric):
    """Bare-bones metric used by the tests: one tensor state ``x`` and no-op hooks."""

    name = "Dummy"

    def __init__(self):
        super().__init__()
        # Single scalar state starting at 0.0, registered without a reduction function.
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)

    def update(self):
        """Do nothing; subclasses in the tests override this."""

    def compute(self):
        """Do nothing and return ``None``; subclasses in the tests override this."""
class DummyList(Metric):
    """Bare-bones metric whose single state ``x`` is an empty list instead of a tensor."""

    name = "DummyList"

    def __init__(self):
        super().__init__()
        # List-valued state, registered without a reduction function.
        self.add_state("x", [], dist_reduce_fx=None)

    def update(self):
        """Do nothing."""

    def compute(self):
        """Do nothing and return ``None``."""
def test_inherit():
    """Check that a minimal ``Metric`` subclass can be instantiated without raising."""
    Dummy()
def test_add_state():
    """Exercise ``add_state`` with built-in, invalid, and custom reduction arguments."""
    metric = Dummy()

    # Built-in string reductions are looked up through ``_reductions``.
    metric.add_state("a", torch.tensor(0), "sum")
    summed = metric._reductions["a"](torch.tensor([1, 1]))
    assert summed == 2

    metric.add_state("b", torch.tensor(0), "mean")
    averaged = metric._reductions["b"](torch.tensor([1.0, 2.0]))
    assert np.allclose(averaged.numpy(), 1.5)

    metric.add_state("c", torch.tensor(0), "cat")
    concatenated = metric._reductions["c"]([torch.tensor([1]), torch.tensor([1])])
    assert concatenated.shape == (2,)

    # Each of these (default, reduction) combinations must be rejected with ValueError.
    rejected = (
        ("d1", torch.tensor(0), 'xyz'),
        ("d2", torch.tensor(0), 42),
        ("d3", [torch.tensor(0)], 'sum'),
        ("d4", 42, 'sum'),
    )
    for state_name, default, reduce_fx in rejected:
        with pytest.raises(ValueError):
            metric.add_state(state_name, default, reduce_fx)

    # A user-supplied callable is stored and used as the reduction.
    def always_minus_one(x):
        return -1

    metric.add_state("e", torch.tensor(0), always_minus_one)
    assert metric._reductions["e"](torch.tensor([1, 1])) == -1
def test_add_state_persistent():
    """Check that the ``persistent`` flag controls whether a state shows up in ``state_dict``."""
    metric = Dummy()

    metric.add_state("a", torch.tensor(0), "sum", persistent=True)
    assert "a" in metric.state_dict()

    metric.add_state("b", torch.tensor(0), "sum", persistent=False)
    # The exclusion of non-persistent states is only asserted on torch >= 1.6.0.
    torch_at_least_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
    if torch_at_least_1_6:
        assert "b" not in metric.state_dict()
def test_reset():
    """Check that ``reset`` restores tensor and list states to their initial values."""

    class TensorBacked(Dummy):
        pass

    class ListBacked(DummyList):
        pass

    tensor_metric = TensorBacked()
    assert tensor_metric.x == 0
    tensor_metric.x = torch.tensor(5)
    tensor_metric.reset()
    assert tensor_metric.x == 0

    list_metric = ListBacked()
    assert isinstance(list_metric.x, list) and len(list_metric.x) == 0
    # Overwrite the list state with a tensor; reset must bring back an empty list.
    list_metric.x = torch.tensor(5)
    list_metric.reset()
    assert isinstance(list_metric.x, list) and len(list_metric.x) == 0
def test_update():
    """Check that ``update`` accumulates state and leaves the compute cache empty."""

    class Accumulator(Dummy):
        def update(self, x):
            self.x += x

    metric = Accumulator()
    assert metric.x == 0
    assert metric._computed is None

    metric.update(1)
    assert metric._computed is None
    assert metric.x == 1

    metric.update(2)
    assert metric.x == 3
    assert metric._computed is None
def test_compute():
    """Check ``compute`` results, and that ``_computed`` is cached and cleared by ``update``."""

    class Accumulator(Dummy):
        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    metric = Accumulator()
    assert metric.compute() == 0
    assert metric.x == 0

    metric.update(1)
    assert metric._computed is None
    assert metric.compute() == 1
    assert metric._computed == 1

    metric.update(2)
    assert metric._computed is None
    assert metric.compute() == 3
    assert metric._computed == 3

    # With no update in between, compute must hand back whatever is cached.
    metric._computed = 5
    assert metric.compute() == 5
def test_forward():
    """Check that calling the metric returns the per-call value while state keeps accumulating."""

    class Accumulator(Dummy):
        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    metric = Accumulator()

    first = metric(5)
    assert first == 5
    assert metric._forward_cache == 5

    second = metric(8)
    assert second == 8
    assert metric._forward_cache == 8

    # The accumulated state reflects both calls.
    assert metric.compute() == 13
class DummyMetric1(Dummy):
    """``Dummy`` variant that adds its ``x`` argument to the state and returns the state."""

    def update(self, x):
        """Add ``x`` to the state in place."""
        self.x += x

    def compute(self):
        """Return the current state."""
        return self.x
class DummyMetric2(Dummy):
    """``Dummy`` variant that subtracts its ``y`` argument from the state and returns the state."""

    def update(self, y):
        """Subtract ``y`` from the state in place."""
        self.x -= y

    def compute(self):
        """Return the current state."""
        return self.x
def test_pickle(tmpdir):
    """Check that a metric keeps its state through pickle and cloudpickle round trips."""
    # NOTE: this does not cover the DDP case
    source_metric = DummyMetric1()
    source_metric.update(1)

    restored = pickle.loads(pickle.dumps(source_metric))
    assert restored.compute() == 1

    restored.update(5)
    assert restored.compute() == 6

    restored = cloudpickle.loads(cloudpickle.dumps(source_metric))
    assert restored.compute() == 1
def test_state_dict(tmpdir):
    """ test that metric states can be removed and added to state dict """
    dummy = Dummy()
    nothing_saved = OrderedDict()

    assert dummy.state_dict() == nothing_saved

    # Toggling persistence on exposes the state; toggling it off hides it again.
    dummy.persistent(True)
    assert dummy.state_dict() == OrderedDict(x=0)

    dummy.persistent(False)
    assert dummy.state_dict() == nothing_saved
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer(tmpdir):
    """Check that the metric state follows device moves and dtype casts of the metric."""
    dummy = DummyMetric1()
    state = dummy.x
    assert state.is_cuda is False
    assert state.dtype == torch.float32

    dummy = dummy.to(device='cuda')
    assert dummy.x.is_cuda

    dummy = dummy.double()
    assert dummy.x.dtype == torch.float64

    dummy = dummy.half()
    assert dummy.x.dtype == torch.float16
def test_metric_collection(tmpdir):
    """Check structure, update, compute, reset, and pickling of a ``MetricCollection``."""
    m1 = DummyMetric1()
    m2 = DummyMetric2()
    metric_collection = MetricCollection([m1, m2])

    # Test correct dict structure
    assert len(metric_collection) == 2
    assert metric_collection['DummyMetric1'] == m1
    assert metric_collection['DummyMetric2'] == m2

    # Test correct initialization
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not initialized correctly'

    # Test every metric gets updated
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'

    # Test compute on each metric
    metric_collection.update(-5)
    metric_vals = metric_collection.compute()
    assert len(metric_vals) == 2
    for name, metric_val in metric_vals.items():
        assert metric_val == 0, f'Metric {name}.compute not called correctly'

    # Test that everything is reset.
    # The +5 / -5 updates above already cancel to zero, so push every state
    # away from zero first; otherwise the check below could never fail.
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'
    metric_collection.reset()
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not reset correctly'

    # Test picklable
    metric_pickled = pickle.dumps(metric_collection)
    metric_loaded = pickle.loads(metric_pickled)
    assert isinstance(metric_loaded, MetricCollection)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer_metriccollection(tmpdir):
    """Check that device moves and dtype casts on a collection reach every member metric."""
    collection = MetricCollection([DummyMetric1(), DummyMetric2()])

    # Freshly built metrics are expected on CPU in float32.
    for _, member in collection.items():
        assert member.x.is_cuda is False
        assert member.x.dtype == torch.float32

    collection = collection.to(device='cuda')
    for _, member in collection.items():
        assert member.x.is_cuda

    collection = collection.double()
    for _, member in collection.items():
        assert member.x.dtype == torch.float64

    collection = collection.half()
    for _, member in collection.items():
        assert member.x.dtype == torch.float16
def test_metric_collection_wrong_input(tmpdir):
    """ Check that errors are raised on wrong input """
    m1 = DummyMetric1()

    # Not all input are metrics (list)
    with pytest.raises(ValueError):
        _ = MetricCollection([m1, 5])

    # Not all input are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': m1,
                              'metric2': 5})

    # Same metric passed in multiple times.
    # ``match`` is a regular expression, so the wildcard for the metric name
    # is written as ``.*`` rather than the glob-style ``*.``.
    with pytest.raises(ValueError, match=r'Encountered two metrics both named .*'):
        _ = MetricCollection([m1, m1])

    # Not a list or dict passed in
    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
        _ = MetricCollection(m1)
def test_metric_collection_args_kwargs(tmpdir):
    """ Check how a metric collection hands args and kwargs to its members.

        Both the ``update`` method and the forward call are exercised.
    """
    collection = MetricCollection([DummyMetric1(), DummyMetric2()])

    # Positional arguments reach every metric: via update ...
    collection.update(5)
    assert collection['DummyMetric1'].x == 5
    assert collection['DummyMetric2'].x == -5
    collection.reset()

    # ... and via forward.
    _ = collection(5)
    assert collection['DummyMetric1'].x == 5
    assert collection['DummyMetric2'].x == -5
    collection.reset()

    # Keyword arguments only reach the metrics whose signature names them: via update ...
    collection.update(x=10, y=20)
    assert collection['DummyMetric1'].x == 10
    assert collection['DummyMetric2'].x == -20
    collection.reset()

    # ... and via forward.
    _ = collection(x=10, y=20)
    assert collection['DummyMetric1'].x == 10
    assert collection['DummyMetric2'].x == -20