Update test for logging a metric object and state reset (#4825)
* update test * docstring Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
This commit is contained in:
parent
e971437551
commit
fb0278a457
|
@ -53,6 +53,7 @@ def test_metric_lightning(tmpdir):
|
|||
|
||||
|
||||
def test_metric_lightning_log(tmpdir):
|
||||
""" Test logging a metric object and that the metric state gets reset after each epoch."""
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -60,6 +61,9 @@ def test_metric_lightning_log(tmpdir):
|
|||
self.metric_epoch = SumMetric()
|
||||
self.sum = 0.0
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.sum = 0.0
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x = batch
|
||||
self.metric_step(x.sum())
|
||||
|
@ -77,7 +81,7 @@ def test_metric_lightning_log(tmpdir):
|
|||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_epochs=2,
|
||||
log_every_n_steps=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue