diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index a35562327d..1622217c24 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -53,6 +53,7 @@ def test_metric_lightning(tmpdir): def test_metric_lightning_log(tmpdir): + """ Test logging a metric object and that the metric state gets reset after each epoch.""" class TestModel(BoringModel): def __init__(self): super().__init__() @@ -60,6 +61,9 @@ def test_metric_lightning_log(tmpdir): self.metric_epoch = SumMetric() self.sum = 0.0 + def on_epoch_start(self): + self.sum = 0.0 + def training_step(self, batch, batch_idx): x = batch self.metric_step(x.sum()) @@ -77,7 +81,7 @@ def test_metric_lightning_log(tmpdir): default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, - max_epochs=1, + max_epochs=2, log_every_n_steps=1, weights_summary=None, )