From 5b342f14a69820b321100e5ef452c7ab212b9414 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Mon, 28 Feb 2022 18:07:05 +0530
Subject: [PATCH] fix to avoid common hook warning if no hook is overridden (#12131)

---
 CHANGELOG.md                                  | 22 +++++++++++--------
 .../trainer/connectors/data_connector.py      |  9 ++++----
 .../trainer/connectors/test_data_connector.py | 13 +++++------
 3 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0dea9cd0f..af007c68a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -317,6 +317,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))
 
+
+- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
+
+
 ### Deprecated
 
 - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
@@ -678,6 +682,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed passing `_ddp_params_and_buffers_to_ignore` ([#11949](https://github.com/PyTorchLightning/pytorch-lightning/pull/11949))
 
+- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
+
+
+- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
+
+
+- Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
+
+
 ## [1.5.10] - 2022-02-08
 
 ### Fixed
@@ -694,12 +707,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
 
-- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
-
-
-- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
-
-
 ## [1.5.9] - 2022-01-20
 
 ### Fixed
@@ -714,9 +721,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
-- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
-
-
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index b79b095fec..b0cf6a95fa 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -599,8 +599,9 @@ class _DataHookSelector:
                 )
             return getattr(self.datamodule, hook_name)
 
-        warning_cache.warn(
-            f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
-            " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
-        )
+        if is_overridden(hook_name, self.model):
+            warning_cache.warn(
+                f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
+                " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
+            )
         return getattr(self.model, hook_name)
diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py
index bb618dfa09..e22e846600 100644
--- a/tests/trainer/connectors/test_data_connector.py
+++ b/tests/trainer/connectors/test_data_connector.py
@@ -20,9 +20,9 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
-from tests.deprecated_api import no_warning_call
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.boring_model import RandomDataset
+from tests.helpers.utils import no_warning_call
 
 
 class NoDataLoaderModel(BoringModel):
@@ -80,12 +80,13 @@ class TestDataHookSelector:
         return batch
 
     def reset_instances(self):
+        warning_cache.clear()
         return BoringDataModule(), BoringModel(), Trainer()
 
     def test_no_datamodule_no_overridden(self, hook_name):
         model, _, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=None)
-        with no_warning_call(match="have overridden `{hook_name}` in both"):
+        with no_warning_call(match=f"have overridden `{hook_name}` in"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
         assert hook == getattr(model, hook_name)
@@ -93,7 +94,7 @@
     def test_with_datamodule_no_overridden(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
-        with no_warning_call(match="have overridden `{hook_name}` in both"):
+        with no_warning_call(match=f"have overridden `{hook_name}` in"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
         assert hook == getattr(model, hook_name)
@@ -101,7 +102,7 @@
     def test_override_model_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
-        with no_warning_call(match="have overridden `{hook_name}` in both"):
+        with no_warning_call(match=f"have overridden `{hook_name}` in"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
         assert hook == getattr(model, hook_name)
@@ -110,7 +111,7 @@ def test_override_datamodule_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(dm, hook_name, self.overridden_func)
-        with no_warning_call(match="have overridden `{hook_name}` in both"):
+        with no_warning_call(match=f"have overridden `{hook_name}` in"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
         assert hook == getattr(dm, hook_name)
@@ -123,7 +124,6 @@
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
-        warning_cache.clear()
         assert hook == getattr(dm, hook_name)
 
     def test_with_datamodule_override_model(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
 
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
             hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
 
-        warning_cache.clear()
         assert hook == getattr(model, hook_name)
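
For illustration only, a minimal sketch of the behaviour this patch enforces. It is not part of the patch; it assumes the repository's test helpers (`BoringModel`, `BoringDataModule`) and assumes `"on_before_batch_transfer"` is one of the hooks handled by `_DataHookSelector`:

# Illustrative sketch (not part of the patch). Run from within the repo so the
# test helpers are importable; hook name is assumed to be a _DataHookSelector hook.
import warnings

from pytorch_lightning import Trainer
from tests.helpers import BoringDataModule, BoringModel

model, dm, trainer = BoringModel(), BoringDataModule(), Trainer()
trainer._data_connector.attach_datamodule(model, datamodule=dm)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hook = trainer._data_connector._datahook_selector.get_hook("on_before_batch_transfer")

# Neither the LightningModule nor the LightningDataModule overrides the hook,
# so with this fix no "You have overridden ..." warning is emitted and the
# model's default implementation is returned.
assert hook == getattr(model, "on_before_batch_transfer")
assert not any("have overridden" in str(w.message) for w in caught)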