import pickle
from collections import OrderedDict
from distutils.version import LooseVersion

import cloudpickle
import numpy as np
import pytest
import torch
from torch import nn

from pytorch_lightning.metrics.metric import Metric, MetricCollection

torch.manual_seed(42)


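# Minimal Metric subclasses used as fixtures throughout this file: Dummy keeps a
# scalar tensor state, DummyList keeps a list state.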
class Dummy(Metric):
    name = "Dummy"

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)

    def update(self):
        pass

    def compute(self):
        pass


class DummyList(Metric):
    name = "DummyList"

    def __init__(self):
        super().__init__()
        self.add_state("x", list(), dist_reduce_fx=None)

    def update(self):
        pass

    def compute(self):
        pass


def test_inherit():
    Dummy()


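# add_state accepts "sum", "mean", "cat" or a custom callable as the reduction;
# the resolved reduction function is stored in metric._reductions.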
def test_add_state():
    a = Dummy()

    a.add_state("a", torch.tensor(0), "sum")
    assert a._reductions["a"](torch.tensor([1, 1])) == 2

    a.add_state("b", torch.tensor(0), "mean")
    assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)

    a.add_state("c", torch.tensor(0), "cat")
    assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, )

    with pytest.raises(ValueError):
        a.add_state("d1", torch.tensor(0), 'xyz')

    with pytest.raises(ValueError):
        a.add_state("d2", torch.tensor(0), 42)

    with pytest.raises(ValueError):
        a.add_state("d3", [torch.tensor(0)], 'sum')

    with pytest.raises(ValueError):
        a.add_state("d4", 42, 'sum')

    def custom_fx(x):
        return -1

    a.add_state("e", torch.tensor(0), custom_fx)
    assert a._reductions["e"](torch.tensor([1, 1])) == -1


def test_add_state_persistent():
    a = Dummy()

    a.add_state("a", torch.tensor(0), "sum", persistent=True)
    assert "a" in a.state_dict()

    a.add_state("b", torch.tensor(0), "sum", persistent=False)

    # non-persistent buffers are only supported from torch 1.6.0 onwards
    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
        assert "b" not in a.state_dict()


def test_reset():

    class A(Dummy):
        pass

    class B(DummyList):
        pass

    a = A()
    assert a.x == 0
    a.x = torch.tensor(5)
    a.reset()
    assert a.x == 0

    b = B()
    assert isinstance(b.x, list) and len(b.x) == 0
    b.x = torch.tensor(5)
    b.reset()
    assert isinstance(b.x, list) and len(b.x) == 0


def test_update():

    class A(Dummy):

        def update(self, x):
            self.x += x

    a = A()
    assert a.x == 0
    assert a._computed is None
    a.update(1)
    assert a._computed is None
    assert a.x == 1
    a.update(2)
    assert a.x == 3
    assert a._computed is None


def test_compute():

    class A(Dummy):

        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    a = A()
    assert 0 == a.compute()
    assert 0 == a.x
    a.update(1)
    assert a._computed is None
    assert a.compute() == 1
    assert a._computed == 1
    a.update(2)
    assert a._computed is None
    assert a.compute() == 3
    assert a._computed == 3

    # called without update, should return cached value
    a._computed = 5
    assert a.compute() == 5


def test_hash():

    class A(Dummy):
        pass

    class B(DummyList):
        pass

    a1 = A()
    a2 = A()
    assert hash(a1) != hash(a2)

    b1 = B()
    b2 = B()
    assert hash(b1) == hash(b2)
    assert isinstance(b1.x, list) and len(b1.x) == 0
    b1.x.append(torch.tensor(5))
    assert isinstance(hash(b1), int)  # <- check that nothing crashes
    assert isinstance(b1.x, list) and len(b1.x) == 1
    b2.x.append(torch.tensor(5))
    # Sanity:
    assert isinstance(b2.x, list) and len(b2.x) == 1
    # The two lists now hold different tensor objects, so the hashes differ:
    assert hash(b1) != hash(b2)


def test_forward():

    class A(Dummy):

        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    a = A()
    assert a(5) == 5
    assert a._forward_cache == 5

    assert a(8) == 8
    assert a._forward_cache == 8

    assert a.compute() == 13


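# DummyMetric1 adds its input to the state while DummyMetric2 subtracts it; both are
# reused by the pickle, device transfer and MetricCollection tests below.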
class DummyMetric1(Dummy):

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


class DummyMetric2(Dummy):

    def update(self, y):
        self.x -= y

    def compute(self):
        return self.x


def test_pickle(tmpdir):
    # does not test DDP
    a = DummyMetric1()
    a.update(1)

    metric_pickled = pickle.dumps(a)
    metric_loaded = pickle.loads(metric_pickled)

    assert metric_loaded.compute() == 1

    metric_loaded.update(5)
    assert metric_loaded.compute() == 6

    metric_pickled = cloudpickle.dumps(a)
    metric_loaded = cloudpickle.loads(metric_pickled)

    assert metric_loaded.compute() == 1


def test_state_dict(tmpdir):
    """ test that metric states can be added to and removed from the state dict """
    metric = Dummy()
    assert metric.state_dict() == OrderedDict()
    metric.persistent(True)
    assert metric.state_dict() == OrderedDict(x=0)
    metric.persistent(False)
    assert metric.state_dict() == OrderedDict()


def test_child_metric_state_dict():
    """ test that child metric states are added to the parent module's state dict """

    class TestModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.metric = Dummy()
            self.metric.add_state('a', torch.tensor(0), persistent=True)
            self.metric.add_state('b', [], persistent=True)
            self.metric.register_buffer('c', torch.tensor(0))

    module = TestModule()
    expected_state_dict = {
        'metric.a': torch.tensor(0),
        'metric.b': [],
        'metric.c': torch.tensor(0),
    }
    assert module.state_dict() == expected_state_dict


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer(tmpdir):
    metric = DummyMetric1()
    assert metric.x.is_cuda is False
    assert metric.x.dtype == torch.float32

    metric = metric.to(device='cuda')
    assert metric.x.is_cuda

    metric = metric.double()
    assert metric.x.dtype == torch.float64

    metric = metric.half()
    assert metric.x.dtype == torch.float16


def test_metric_collection(tmpdir):
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])

    # Test correct dict structure
    assert len(metric_collection) == 2
    assert metric_collection['DummyMetric1'] == m1
    assert metric_collection['DummyMetric2'] == m2

    # Test correct initialization
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not initialized correctly'

    # Test every metric gets updated
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'

    # Test compute on each metric
    metric_collection.update(-5)
    metric_vals = metric_collection.compute()
    assert len(metric_vals) == 2
    for name, metric_val in metric_vals.items():
        assert metric_val == 0, f'Metric {name}.compute not called correctly'

    # Test that everything is reset
    metric_collection.reset()
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not reset correctly'

    # Test picklable
    metric_pickled = pickle.dumps(metric_collection)
    metric_loaded = pickle.loads(metric_pickled)
    assert isinstance(metric_loaded, MetricCollection)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer_metriccollection(tmpdir):
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda is False
        assert metric.x.dtype == torch.float32

    metric_collection = metric_collection.to(device='cuda')
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda

    metric_collection = metric_collection.double()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float64

    metric_collection = metric_collection.half()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float16


def test_metric_collection_wrong_input(tmpdir):
    """ Check that errors are raised on wrong input """
    m1 = DummyMetric1()

    # Not all inputs are metrics (list)
    with pytest.raises(ValueError):
        _ = MetricCollection([m1, 5])

    # Not all inputs are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': m1, 'metric2': 5})

    # Same metric passed in multiple times
    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
        _ = MetricCollection([m1, m1])

    # Neither a list nor a dict passed in
    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
        _ = MetricCollection(m1)


def test_metric_collection_args_kwargs(tmpdir):
    """ Check that args and kwargs get passed correctly in metric collection,
    checking both the update and forward method.
    """
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])

    # args get passed to all metrics
    metric_collection.update(5)
    assert metric_collection['DummyMetric1'].x == 5
    assert metric_collection['DummyMetric2'].x == -5
    metric_collection.reset()
    _ = metric_collection(5)
    assert metric_collection['DummyMetric1'].x == 5
    assert metric_collection['DummyMetric2'].x == -5
    metric_collection.reset()

    # kwargs are only passed to the metrics whose update signature accepts them
    metric_collection.update(x=10, y=20)
    assert metric_collection['DummyMetric1'].x == 10
    assert metric_collection['DummyMetric2'].x == -20
    metric_collection.reset()
    _ = metric_collection(x=10, y=20)
    assert metric_collection['DummyMetric1'].x == 10
    assert metric_collection['DummyMetric2'].x == -20