Fix reset TensorRunningAccum (#5106)
* Fix reset TensorRunningAccum * add test for TensorRunningAccum's reset method * fix CI failed due to PEP8 Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
151d86e40b
commit
1d13943605
|
@ -56,7 +56,7 @@ class TensorRunningAccum(object):
|
|||
|
||||
def reset(self) -> None:
|
||||
"""Empty the accumulator."""
|
||||
self = TensorRunningAccum(self.window_length)
|
||||
self.__init__(self.window_length)
|
||||
|
||||
def last(self):
|
||||
"""Get the last added element."""
|
||||
|
|
|
@ -17,10 +17,33 @@ import pytest
|
|||
import torch
|
||||
|
||||
from torch.utils.data import TensorDataset
|
||||
from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator
|
||||
from pytorch_lightning.trainer.supporters import (
|
||||
CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator, TensorRunningAccum)
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
|
||||
def test_tensor_running_accum_reset():
|
||||
""" Test that reset would set all attributes to the initialization state """
|
||||
|
||||
window_length = 10
|
||||
|
||||
accum = TensorRunningAccum(window_length=window_length)
|
||||
assert accum.last() is None
|
||||
assert accum.mean() is None
|
||||
|
||||
accum.append(torch.tensor(1.5))
|
||||
assert accum.last() == torch.tensor(1.5)
|
||||
assert accum.mean() == torch.tensor(1.5)
|
||||
|
||||
accum.reset()
|
||||
assert accum.window_length == window_length
|
||||
assert accum.memory is None
|
||||
assert accum.current_idx == 0
|
||||
assert accum.last_idx is None
|
||||
assert not accum.rotated
|
||||
|
||||
|
||||
def test_cycle_iterator():
|
||||
"""Test the cycling function of `CycleIterator`"""
|
||||
iterator = CycleIterator(range(100), 1000)
|
||||
|
|
Loading…
Reference in New Issue