From fb0278a457be64a25cd2cf7e6d29c6ae3846ccc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Nov 2020 11:28:02 +0100 Subject: [PATCH] Update test for logging a metric object and state reset (#4825) * update test * docstring Co-authored-by: Ananya Harsh Jha --- tests/metrics/test_metric_lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index a35562327d..1622217c24 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -53,6 +53,7 @@ def test_metric_lightning(tmpdir): def test_metric_lightning_log(tmpdir): + """ Test logging a metric object and that the metric state gets reset after each epoch.""" class TestModel(BoringModel): def __init__(self): super().__init__() @@ -60,6 +61,9 @@ def test_metric_lightning_log(tmpdir): self.metric_epoch = SumMetric() self.sum = 0.0 + def on_epoch_start(self): + self.sum = 0.0 + def training_step(self, batch, batch_idx): x = batch self.metric_step(x.sum()) @@ -77,7 +81,7 @@ def test_metric_lightning_log(tmpdir): default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, - max_epochs=1, + max_epochs=2, log_every_n_steps=1, weights_summary=None, )