Update supporters.py

just needed to multiply by zero for init
This commit is contained in:
William Falcon 2020-03-30 16:49:04 -04:00 committed by GitHub
parent 3cb85c1f92
commit 7e0da6c693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 7 deletions

View File

@ -11,8 +11,6 @@ class TensorRunningMean(object):
>>> accum.last(), accum.mean() >>> accum.last(), accum.mean()
(None, None) (None, None)
>>> accum.append(torch.tensor(1.5)) >>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5)) >>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean() >>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.)) (tensor(2.5000), tensor(2.))
@ -29,7 +27,7 @@ class TensorRunningMean(object):
self.rotated: bool = False self.rotated: bool = False
def reset(self) -> None: def reset(self) -> None:
self = TensorRunningMean(self.window_length) self.memory = TensorRunningMean(self.window_length) * 0
def last(self): def last(self):
if self.last_idx: if self.last_idx:
@ -55,7 +53,4 @@ class TensorRunningMean(object):
self.rotated = True self.rotated = True
def mean(self): def mean(self):
if self.last_idx is None: return self.memory.mean()
return None
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
return avg