diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0b0d04bc1e..b7a2ec72b4 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -11,6 +11,8 @@ class TensorRunningMean(object): >>> accum.last(), accum.mean() (None, None) >>> accum.append(torch.tensor(1.5)) + >>> accum.last(), accum.mean() + (tensor(1.5000), tensor(1.5000)) >>> accum.append(torch.tensor(2.5)) >>> accum.last(), accum.mean() (tensor(2.5000), tensor(2.)) @@ -53,7 +55,7 @@ class TensorRunningMean(object): self.rotated = True def mean(self): - if not self.last_idx: + if self.last_idx is None: return None avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean() return avg