Update supporters.py
just needed to multiply by zero for init
This commit is contained in:
parent
3cb85c1f92
commit
7e0da6c693
|
@ -11,8 +11,6 @@ class TensorRunningMean(object):
|
|||
>>> accum.last(), accum.mean()
|
||||
(None, None)
|
||||
>>> accum.append(torch.tensor(1.5))
|
||||
>>> accum.last(), accum.mean()
|
||||
(tensor(1.5000), tensor(1.5000))
|
||||
>>> accum.append(torch.tensor(2.5))
|
||||
>>> accum.last(), accum.mean()
|
||||
(tensor(2.5000), tensor(2.))
|
||||
|
@ -29,7 +27,7 @@ class TensorRunningMean(object):
|
|||
self.rotated: bool = False
|
||||
|
||||
def reset(self) -> None:
|
||||
self = TensorRunningMean(self.window_length)
|
||||
self.memory = TensorRunningMean(self.window_length) * 0
|
||||
|
||||
def last(self):
|
||||
if self.last_idx:
|
||||
|
@ -55,7 +53,4 @@ class TensorRunningMean(object):
|
|||
self.rotated = True
|
||||
|
||||
def mean(self):
|
||||
if self.last_idx is None:
|
||||
return None
|
||||
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
|
||||
return avg
|
||||
return self.memory.mean()
|
||||
|
|
Loading…
Reference in New Issue