Add Metric <-> Lightning Module integration tests (#4008)
* lightning module metric tests * whitespace * pep8
This commit is contained in:
parent
bfdea3ea28
commit
7db26a913b
|
@ -0,0 +1,79 @@
|
||||||
|
import torch
|
||||||
|
from tests.base.boring_model import BoringModel
|
||||||
|
from pytorch_lightning.metrics import Metric
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
class SumMetric(Metric):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
|
||||||
|
|
||||||
|
def update(self, x):
|
||||||
|
self.x += x
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_lightning(tmpdir):
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.metric = SumMetric()
|
||||||
|
self.sum = 0.0
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x = batch
|
||||||
|
self.metric(x.sum())
|
||||||
|
self.sum += x.sum()
|
||||||
|
|
||||||
|
return self.step(x)
|
||||||
|
|
||||||
|
def training_epoch_end(self, outs):
|
||||||
|
assert torch.allclose(self.sum, self.metric.compute())
|
||||||
|
self.sum = 0.0
|
||||||
|
|
||||||
|
model = TestModel()
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=2,
|
||||||
|
limit_val_batches=2,
|
||||||
|
max_epochs=2,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_lightning_log(tmpdir):
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.metric = SumMetric()
|
||||||
|
self.sum = 0.0
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x = batch
|
||||||
|
self.metric(x.sum())
|
||||||
|
self.sum += x.sum()
|
||||||
|
self.log("sum", self.metric, on_epoch=True, on_step=False)
|
||||||
|
return self.step(x)
|
||||||
|
|
||||||
|
model = TestModel()
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=2,
|
||||||
|
limit_val_batches=2,
|
||||||
|
max_epochs=1,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
logged = trainer.logged_metrics
|
||||||
|
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
|
Loading…
Reference in New Issue