From 1d1394360574409d44637869f0ce70eea79dfa5c Mon Sep 17 00:00:00 2001 From: Loi Ly Date: Wed, 16 Dec 2020 13:44:30 +0700 Subject: [PATCH] Fix reset TensorRunningAccum (#5106) * Fix reset TensorRunningAccum * add test for TensorRunningAccum's reset method * fix CI failed due to PEP8 Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/supporters.py | 2 +- tests/trainer/test_supporters.py | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index db51fb8014..04fa3f4cc8 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -56,7 +56,7 @@ class TensorRunningAccum(object): def reset(self) -> None: """Empty the accumulator.""" - self = TensorRunningAccum(self.window_length) + self.__init__(self.window_length) def last(self): """Get the last added element.""" diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 6195d7ddeb..b1b0db749e 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -17,10 +17,33 @@ import pytest import torch from torch.utils.data import TensorDataset -from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator +from pytorch_lightning.trainer.supporters import ( + CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator, TensorRunningAccum) from pytorch_lightning.utilities.exceptions import MisconfigurationException + +def test_tensor_running_accum_reset(): + """ Test that reset would set all attributes to the initialization state """ + + window_length = 10 + + accum = TensorRunningAccum(window_length=window_length) + assert accum.last() is None + assert accum.mean() is None + + accum.append(torch.tensor(1.5)) + assert accum.last() == torch.tensor(1.5) + assert accum.mean() == torch.tensor(1.5) + + accum.reset() + assert accum.window_length == window_length + assert accum.memory is None + assert accum.current_idx == 0 + assert accum.last_idx is None + assert not accum.rotated + + def test_cycle_iterator(): """Test the cycling function of `CycleIterator`""" iterator = CycleIterator(range(100), 1000)