diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index b7a2ec72b4..59e57f2b3a 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -11,8 +11,6 @@ class TensorRunningMean(object): >>> accum.last(), accum.mean() (None, None) >>> accum.append(torch.tensor(1.5)) - >>> accum.last(), accum.mean() - (tensor(1.5000), tensor(1.5000)) >>> accum.append(torch.tensor(2.5)) >>> accum.last(), accum.mean() (tensor(2.5000), tensor(2.)) @@ -29,7 +27,7 @@ class TensorRunningMean(object): self.rotated: bool = False def reset(self) -> None: - self = TensorRunningMean(self.window_length) + self.memory = TensorRunningMean(self.window_length) * 0 def last(self): if self.last_idx: @@ -55,7 +53,4 @@ class TensorRunningMean(object): self.rotated = True def mean(self): - if self.last_idx is None: - return None - avg = self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean() - return avg + return self.memory.mean()