fix none
This commit is contained in:
parent
915cb09b02
commit
3cb85c1f92
|
@ -11,6 +11,8 @@ class TensorRunningMean(object):
|
||||||
>>> accum.last(), accum.mean()
|
>>> accum.last(), accum.mean()
|
||||||
(None, None)
|
(None, None)
|
||||||
>>> accum.append(torch.tensor(1.5))
|
>>> accum.append(torch.tensor(1.5))
|
||||||
|
>>> accum.last(), accum.mean()
|
||||||
|
(tensor(1.5000), tensor(1.5000))
|
||||||
>>> accum.append(torch.tensor(2.5))
|
>>> accum.append(torch.tensor(2.5))
|
||||||
>>> accum.last(), accum.mean()
|
>>> accum.last(), accum.mean()
|
||||||
(tensor(2.5000), tensor(2.))
|
(tensor(2.5000), tensor(2.))
|
||||||
|
@ -53,7 +55,7 @@ class TensorRunningMean(object):
|
||||||
self.rotated = True
|
self.rotated = True
|
||||||
|
|
||||||
def mean(self):
|
def mean(self):
|
||||||
if not self.last_idx:
|
if self.last_idx is None:
|
||||||
return None
|
return None
|
||||||
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
|
avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
|
||||||
return avg
|
return avg
|
||||||
|
|
Loading…
Reference in New Issue