diff --git a/CHANGELOG.md b/CHANGELOG.md index 14f9c010fa..b781d56000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132)) - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191)) - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251)) +- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309)) ## [0.7.1] - 2020-03-07 diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ee18a40436..2cfb7b44e7 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1525,9 +1525,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): Dictionary with the items to be displayed in the progress bar. """ # call .item() only once but store elements without graphs - running_training_loss = self.trainer.running_loss.mean().cpu().item() + running_train_loss = self.trainer.running_loss.mean() + avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN') tqdm_dict = { - 'loss': '{:.3f}'.format(running_training_loss) + 'loss': '{:.3f}'.format(avg_training_loss) } if self.trainer.truncated_bptt_steps is not None: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py new file mode 100644 index 0000000000..c7fd4d6631 --- /dev/null +++ b/pytorch_lightning/trainer/supporters.py @@ -0,0 +1,58 @@ +import torch + + +class TensorRunningMean(object): + """ + Tracks a running mean without graph references. + Round robbin for the mean + + Examples: + >>> accum = TensorRunningMean(5) + >>> accum.last(), accum.mean() + (None, None) + >>> accum.append(torch.tensor(1.5)) + >>> accum.last(), accum.mean() + (tensor(1.5000), tensor(1.5000)) + >>> accum.append(torch.tensor(2.5)) + >>> accum.last(), accum.mean() + (tensor(2.5000), tensor(2.)) + >>> accum.reset() + >>> _= [accum.append(torch.tensor(i)) for i in range(13)] + >>> accum.last(), accum.mean() + (tensor(12.), tensor(10.)) + """ + def __init__(self, window_length: int): + self.window_length = window_length + self.memory = torch.Tensor(self.window_length) + self.current_idx: int = 0 + self.last_idx: int = None + self.rotated: bool = False + + def reset(self) -> None: + self = TensorRunningMean(self.window_length) + + def last(self): + if self.last_idx is not None: + return self.memory[self.last_idx] + + def append(self, x): + # map proper type for memory if they don't match + if self.memory.type() != x.type(): + self.memory.type_as(x) + + # store without grads + with torch.no_grad(): + self.memory[self.current_idx] = x + self.last_idx = self.current_idx + + # increase index + self.current_idx += 1 + + # reset index when hit limit of tensor + self.current_idx = self.current_idx % self.window_length + if self.current_idx == 0: + self.rotated = True + + def mean(self): + if self.last_idx is not None: + return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean() diff --git a/pytorch_lightning/trainer/supporting_classes.py b/pytorch_lightning/trainer/supporting_classes.py deleted file mode 100644 index 7f2b0824a6..0000000000 --- a/pytorch_lightning/trainer/supporting_classes.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch - - -class TensorRunningMean(object): - """ - Tracks a running mean without graph references. - Round robbin for the mean - """ - def __init__(self, window_length): - self.window_length = window_length - self.reset() - self.last_idx = 0 - - def reset(self): - self.memory = torch.Tensor(self.window_length) - self.current_idx = 0 - - def last(self): - return self.memory[self.last_idx] - - def append(self, x): - # map proper type for memory if they don't match - if self.memory.type() != x.type(): - self.memory.type_as(x) - - # store without grads - with torch.no_grad(): - self.memory[self.current_idx] = x - self.last_idx = self.current_idx - - # increase index - self.current_idx += 1 - - # reset index when hit limit of tensor - if self.current_idx >= self.window_length: - self.current_idx = 0 - - def mean(self): - return self.memory.mean() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2dc6c76598..6ba50bf2f6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -34,7 +34,7 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.debugging import MisconfigurationException -from pytorch_lightning.trainer.supporting_classes import TensorRunningMean +from pytorch_lightning.trainer.supporters import TensorRunningMean try: from apex import amp diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cc44882ea5..71896f8f75 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -146,7 +146,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.utilities.debugging import MisconfigurationException -from pytorch_lightning.trainer.supporting_classes import TensorRunningMean +from pytorch_lightning.trainer.supporters import TensorRunningMean try: from apex import amp diff --git a/tests/collect_env_details.py b/tests/collect_env_details.py index 957397f3bb..aaf16a104e 100644 --- a/tests/collect_env_details.py +++ b/tests/collect_env_details.py @@ -48,7 +48,7 @@ def info_system(): def info_cuda(): return { - 'GPU': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]), + 'GPU': [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], # 'nvidia_driver': get_nvidia_driver_version(run_lambda), 'available': torch.cuda.is_available(), 'version': torch.version.cuda,