From f9e050c5e5f9a7adc1a01f18a3161b8669301343 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 7 May 2021 14:02:44 -0700
Subject: [PATCH] Move DP warning suppression to the DataParallel Plugin
 (#7421)

---
 CHANGELOG.md                                  |  3 +++
 docs/source/governance.rst                    |  2 --
 pytorch_lightning/overrides/data_parallel.py  | 10 +++++++
 pytorch_lightning/trainer/ignored_warnings.py | 27 -------------------
 4 files changed, 13 insertions(+), 29 deletions(-)
 delete mode 100644 pytorch_lightning/trainer/ignored_warnings.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 899f9f74b9..adac6ae8a4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 
+- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
+
+
 ### Deprecated
 
 
diff --git a/docs/source/governance.rst b/docs/source/governance.rst
index fac8b68e1d..5b1f9bd191 100644
--- a/docs/source/governance.rst
+++ b/docs/source/governance.rst
@@ -38,5 +38,3 @@ Alumni
 - Jeff Ling (`jeffling <https://github.com/jeffling>`_)
 - Teddy Koker (`teddykoker <https://github.com/teddykoker>`_)
 - Nate Raw (`nateraw <https://github.com/nateraw>`_)
-
-
diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index a9d312e9f6..272f4c6750 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -26,6 +26,15 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
+def _ignore_scalar_return_in_dp():
+    # Users get confused by this warning so we silence it
+    warnings.filterwarnings(
+        'ignore',
+        message='Was asked to gather along dimension 0, but all input tensors were scalars;'
+        ' will instead unsqueeze and return a vector.'
+    )
+
+
 class LightningDataParallel(DataParallel):
 
     def __init__(self, module: LightningModule, *args, **kwargs):
@@ -70,6 +79,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
 
     def __init__(self, pl_module: LightningModule):
         super().__init__(pl_module)
+        _ignore_scalar_return_in_dp()
 
     def forward(self, *inputs, **kwargs):
         self.update_replica_device_attributes(inputs)
diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py
deleted file mode 100644
index 894416d607..0000000000
--- a/pytorch_lightning/trainer/ignored_warnings.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-
-
-def ignore_scalar_return_in_dp():
-    # Users get confused by this warning so we silence it
-    warnings.filterwarnings(
-        'ignore',
-        message='Was asked to gather along dimension 0, but all input tensors were scalars;'
-        ' will instead unsqueeze and return a vector.'
-    )
-
-
-ignore_scalar_return_in_dp()
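Note for reviewers: the suppression relies on Python's standard `warnings.filterwarnings`, whose `message` argument is a regular expression matched against the start of the warning text, so the filter swallows exactly the gather warning quoted in the patch. Below is a minimal standalone sketch of that mechanism; it is not part of the diff, and the trailing `warnings.warn` calls are hypothetical triggers added only to demonstrate the behavior.

```python
import warnings

# The same filter the patch installs when LightningParallelModule is
# constructed; `message` is a regex matched against the beginning of
# the warning text.
warnings.filterwarnings(
    'ignore',
    message='Was asked to gather along dimension 0, but all input tensors were scalars;'
    ' will instead unsqueeze and return a vector.'
)

# Hypothetical trigger: this matches the filter, so nothing reaches stderr.
warnings.warn(
    'Was asked to gather along dimension 0, but all input tensors were scalars;'
    ' will instead unsqueeze and return a vector.'
)

# A warning with a different message is still emitted normally.
warnings.warn('this warning is not suppressed')
```

One behavioral consequence of the move: the filter mutates process-global warning state, and after this patch it is installed lazily in `LightningParallelModule.__init__`, i.e. only when DP wrapping is actually used, whereas the deleted `ignored_warnings.py` module installed it unconditionally at import time.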