fix to avoid common hook warning if no hook is overridden (#12131)

This commit is contained in:
Rohit Gupta 2022-02-28 18:07:05 +05:30 committed by GitHub
parent 4daa7ce325
commit 5b342f14a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 20 deletions

View File

@ -317,6 +317,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))
- Disable loading dataloaders if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
### Deprecated
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
@ -678,6 +682,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing `_ddp_params_and_buffers_to_ignore` ([#11949](https://github.com/PyTorchLightning/pytorch-lightning/pull/11949))
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
- Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
## [1.5.10] - 2022-02-08
### Fixed
@ -694,12 +707,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
## [1.5.9] - 2022-01-20
### Fixed
@ -714,9 +721,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
- Disable loading dataloaders if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
## [1.5.8] - 2022-01-05
### Fixed

View File

@ -599,8 +599,9 @@ class _DataHookSelector:
)
return getattr(self.datamodule, hook_name)
warning_cache.warn(
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
)
if is_overridden(hook_name, self.model):
warning_cache.warn(
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
)
return getattr(self.model, hook_name)

View File

@ -20,9 +20,9 @@ from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.deprecated_api import no_warning_call
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.boring_model import RandomDataset
from tests.helpers.utils import no_warning_call
class NoDataLoaderModel(BoringModel):
@ -80,12 +80,13 @@ class TestDataHookSelector:
return batch
def reset_instances(self):
warning_cache.clear()
return BoringDataModule(), BoringModel(), Trainer()
def test_no_datamodule_no_overridden(self, hook_name):
model, _, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=None)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
assert hook == getattr(model, hook_name)
@ -93,7 +94,7 @@ class TestDataHookSelector:
def test_with_datamodule_no_overridden(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
assert hook == getattr(model, hook_name)
@ -101,7 +102,7 @@ class TestDataHookSelector:
def test_override_model_hook(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
assert hook == getattr(model, hook_name)
@ -110,7 +111,7 @@ class TestDataHookSelector:
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
setattr(dm, hook_name, self.overridden_func)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
assert hook == getattr(dm, hook_name)
@ -123,7 +124,6 @@ class TestDataHookSelector:
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
warning_cache.clear()
assert hook == getattr(dm, hook_name)
def test_with_datamodule_override_model(self, hook_name):
@ -133,7 +133,6 @@ class TestDataHookSelector:
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
warning_cache.clear()
assert hook == getattr(model, hook_name)