diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py
new file mode 100644
index 0000000000..9260720ec3
--- /dev/null
+++ b/pytorch_lightning/trainer/ignored_warnings.py
@@ -0,0 +1,14 @@
+import warnings
+
+
+def ignore_scalar_return_in_dp():
+    # Users get confused by this warning so we silence it
+    m_1 = """
+        Was asked to gather along dimension 0, but all
+        input tensors were scalars; will instead unsqueeze
+        and return a vector.
+    """
+    warnings.filterwarnings('ignore', message=m_1)
+
+
+ignore_scalar_return_in_dp()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6f8bd6005f..375d332841 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -22,6 +22,7 @@ from pytorch_lightning.pt_overrides.override_data_parallel import (
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 import pdb
+from pytorch_lightning.trainer import ignored_warnings
 
 try:
     from apex import amp
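For reference, `warnings.filterwarnings()` compiles `message` into a regular expression and matches it against the *start* of each warning's text, so the filter only fires when the pattern lines up with the beginning of the real message. Below is a minimal, self-contained sketch of the suppression technique the new module relies on; the second warning string is illustrative, not part of the diff:

```python
import warnings

# Suppress any warning whose text *starts with* this pattern.
# filterwarnings() treats `message` as a regex matched against the
# beginning of the warning message, so no wildcards are needed here.
warnings.filterwarnings(
    'ignore',
    message=r'Was asked to gather along dimension 0',
)

# Suppressed: this message begins with the pattern above.
warnings.warn(
    'Was asked to gather along dimension 0, but all input tensors '
    'were scalars; will instead unsqueeze and return a vector.'
)

# Still emitted: the pattern does not match this message.
warnings.warn('some unrelated warning')
```

One subtlety worth noting: because the match is anchored at the start of the message, a triple-quoted pattern like `m_1` above (which begins with a newline and indentation) only matches if the emitted warning text begins with that same whitespace.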