Update test for logging a metric object and state reset (#4825)

* update test

* docstring

Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
This commit is contained in:
Adrian Wälchli 2020-11-24 11:28:02 +01:00 committed by GitHub
parent e971437551
commit fb0278a457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -53,6 +53,7 @@ def test_metric_lightning(tmpdir):
def test_metric_lightning_log(tmpdir):
""" Test logging a metric object and that the metric state gets reset after each epoch."""
class TestModel(BoringModel):
def __init__(self):
super().__init__()
@ -60,6 +61,9 @@ def test_metric_lightning_log(tmpdir):
self.metric_epoch = SumMetric()
self.sum = 0.0
def on_epoch_start(self):
self.sum = 0.0
def training_step(self, batch, batch_idx):
x = batch
self.metric_step(x.sum())
@ -77,7 +81,7 @@ def test_metric_lightning_log(tmpdir):
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
)