Update supporters.py
just needed to multiply by zero for init
This commit is contained in:
parent
3cb85c1f92
commit
7e0da6c693
|
@ -11,8 +11,6 @@ class TensorRunningMean(object):
|
||||||
>>> accum.last(), accum.mean()
|
>>> accum.last(), accum.mean()
|
||||||
(None, None)
|
(None, None)
|
||||||
>>> accum.append(torch.tensor(1.5))
|
>>> accum.append(torch.tensor(1.5))
|
||||||
>>> accum.last(), accum.mean()
|
|
||||||
(tensor(1.5000), tensor(1.5000))
|
|
||||||
>>> accum.append(torch.tensor(2.5))
|
>>> accum.append(torch.tensor(2.5))
|
||||||
>>> accum.last(), accum.mean()
|
>>> accum.last(), accum.mean()
|
||||||
(tensor(2.5000), tensor(2.))
|
(tensor(2.5000), tensor(2.))
|
||||||
|
@ -29,7 +27,7 @@ class TensorRunningMean(object):
|
||||||
self.rotated: bool = False
|
self.rotated: bool = False
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self = TensorRunningMean(self.window_length)
|
self.memory = TensorRunningMean(self.window_length) * 0
|
||||||
|
|
||||||
def last(self):
|
def last(self):
|
||||||
if self.last_idx:
|
if self.last_idx:
|
||||||
|
@ -55,7 +53,4 @@ class TensorRunningMean(object):
|
||||||
self.rotated = True
|
self.rotated = True
|
||||||
|
|
||||||
def mean(self):
|
def mean(self):
|
||||||
if self.last_idx is None:
|
return self.memory.mean()
|
||||||
return None
|
|
||||||
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
|
|
||||||
return avg
|
|
||||||
|
|
Loading…
Reference in New Issue