From 8b2a2aeda3066fe30cc496a58368a523ef90ad9b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 26 Sep 2019 13:20:54 -0400 Subject: [PATCH] Dim 0 warning (#256) * added ignore warnings module * added ignore warnings module * Fixes #249 * Update ignored_warnings.py --- pytorch_lightning/trainer/ignored_warnings.py | 14 ++++++++++++++ pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 15 insertions(+) create mode 100644 pytorch_lightning/trainer/ignored_warnings.py diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py new file mode 100644 index 0000000000..9260720ec3 --- /dev/null +++ b/pytorch_lightning/trainer/ignored_warnings.py @@ -0,0 +1,14 @@ +import warnings + + +def ignore_scalar_return_in_dp(): + # Users get confused by this warning so we silence it + m_1 = """ + Was asked to gather along dimension 0, but all + input tensors were scalars; will instead unsqueeze + and return a vector. + """ + warnings.filterwarnings('ignore', message=m_1) + + +ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6f8bd6005f..375d332841 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,6 +22,7 @@ from pytorch_lightning.pt_overrides.override_data_parallel import ( from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.utilities.debugging import MisconfigurationException import pdb +from pytorch_lightning.trainer import ignored_warnings try: from apex import amp