This commit is contained in:
J. Borovec 2020-03-30 22:47:30 +02:00
parent 915cb09b02
commit 3cb85c1f92
1 changed files with 3 additions and 1 deletions

View File

@ -11,6 +11,8 @@ class TensorRunningMean(object):
>>> accum.last(), accum.mean() >>> accum.last(), accum.mean()
(None, None) (None, None)
>>> accum.append(torch.tensor(1.5)) >>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5)) >>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean() >>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.)) (tensor(2.5000), tensor(2.))
@ -53,7 +55,7 @@ class TensorRunningMean(object):
self.rotated = True self.rotated = True
def mean(self): def mean(self):
if not self.last_idx: if self.last_idx is None:
return None return None
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean() avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
return avg return avg