Fix reset TensorRunningAccum (#5106)

* Fix reset TensorRunningAccum

* add test for TensorRunningAccum's reset method

* fix CI failed due to PEP8

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Loi Ly 2020-12-16 13:44:30 +07:00 committed by Jirka Borovec
parent 151d86e40b
commit 1d13943605
2 changed files with 25 additions and 2 deletions

View File

@ -56,7 +56,7 @@ class TensorRunningAccum(object):
def reset(self) -> None: def reset(self) -> None:
"""Empty the accumulator.""" """Empty the accumulator."""
self = TensorRunningAccum(self.window_length) self.__init__(self.window_length)
def last(self): def last(self):
"""Get the last added element.""" """Get the last added element."""

View File

@ -17,10 +17,33 @@ import pytest
import torch import torch
from torch.utils.data import TensorDataset from torch.utils.data import TensorDataset
from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator from pytorch_lightning.trainer.supporters import (
CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator, TensorRunningAccum)
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
def test_tensor_running_accum_reset():
""" Test that reset would set all attributes to the initialization state """
window_length = 10
accum = TensorRunningAccum(window_length=window_length)
assert accum.last() is None
assert accum.mean() is None
accum.append(torch.tensor(1.5))
assert accum.last() == torch.tensor(1.5)
assert accum.mean() == torch.tensor(1.5)
accum.reset()
assert accum.window_length == window_length
assert accum.memory is None
assert accum.current_idx == 0
assert accum.last_idx is None
assert not accum.rotated
def test_cycle_iterator(): def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`""" """Test the cycling function of `CycleIterator`"""
iterator = CycleIterator(range(100), 1000) iterator = CycleIterator(range(100), 1000)