fix to avoid common hook warning if no hook is overridden (#12131)
This commit is contained in:
parent
4daa7ce325
commit
5b342f14a6
22
CHANGELOG.md
22
CHANGELOG.md
|
@ -317,6 +317,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))
|
||||
|
||||
|
||||
- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
|
||||
|
@ -678,6 +682,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed passing `_ddp_params_and_buffers_to_ignore` ([#11949](https://github.com/PyTorchLightning/pytorch-lightning/pull/11949))
|
||||
|
||||
|
||||
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
|
||||
|
||||
|
||||
- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
|
||||
|
||||
|
||||
- Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
|
||||
|
||||
|
||||
## [1.5.10] - 2022-02-08
|
||||
|
||||
### Fixed
|
||||
|
@ -694,12 +707,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
|
||||
|
||||
|
||||
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
|
||||
|
||||
|
||||
- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))
|
||||
|
||||
|
||||
## [1.5.9] - 2022-01-20
|
||||
|
||||
### Fixed
|
||||
|
@ -714,9 +721,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
|
||||
|
||||
|
||||
- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))
|
||||
|
||||
|
||||
## [1.5.8] - 2022-01-05
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -599,8 +599,9 @@ class _DataHookSelector:
|
|||
)
|
||||
return getattr(self.datamodule, hook_name)
|
||||
|
||||
warning_cache.warn(
|
||||
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
|
||||
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
|
||||
)
|
||||
if is_overridden(hook_name, self.model):
|
||||
warning_cache.warn(
|
||||
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
|
||||
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
|
||||
)
|
||||
return getattr(self.model, hook_name)
|
||||
|
|
|
@ -20,9 +20,9 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||
from tests.deprecated_api import no_warning_call
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
from tests.helpers.boring_model import RandomDataset
|
||||
from tests.helpers.utils import no_warning_call
|
||||
|
||||
|
||||
class NoDataLoaderModel(BoringModel):
|
||||
|
@ -80,12 +80,13 @@ class TestDataHookSelector:
|
|||
return batch
|
||||
|
||||
def reset_instances(self):
|
||||
warning_cache.clear()
|
||||
return BoringDataModule(), BoringModel(), Trainer()
|
||||
|
||||
def test_no_datamodule_no_overridden(self, hook_name):
|
||||
model, _, trainer = self.reset_instances()
|
||||
trainer._data_connector.attach_datamodule(model, datamodule=None)
|
||||
with no_warning_call(match="have overridden `{hook_name}` in both"):
|
||||
with no_warning_call(match=f"have overridden `{hook_name}` in"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
assert hook == getattr(model, hook_name)
|
||||
|
@ -93,7 +94,7 @@ class TestDataHookSelector:
|
|||
def test_with_datamodule_no_overridden(self, hook_name):
|
||||
model, dm, trainer = self.reset_instances()
|
||||
trainer._data_connector.attach_datamodule(model, datamodule=dm)
|
||||
with no_warning_call(match="have overridden `{hook_name}` in both"):
|
||||
with no_warning_call(match=f"have overridden `{hook_name}` in"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
assert hook == getattr(model, hook_name)
|
||||
|
@ -101,7 +102,7 @@ class TestDataHookSelector:
|
|||
def test_override_model_hook(self, hook_name):
|
||||
model, dm, trainer = self.reset_instances()
|
||||
trainer._data_connector.attach_datamodule(model, datamodule=dm)
|
||||
with no_warning_call(match="have overridden `{hook_name}` in both"):
|
||||
with no_warning_call(match=f"have overridden `{hook_name}` in"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
assert hook == getattr(model, hook_name)
|
||||
|
@ -110,7 +111,7 @@ class TestDataHookSelector:
|
|||
model, dm, trainer = self.reset_instances()
|
||||
trainer._data_connector.attach_datamodule(model, datamodule=dm)
|
||||
setattr(dm, hook_name, self.overridden_func)
|
||||
with no_warning_call(match="have overridden `{hook_name}` in both"):
|
||||
with no_warning_call(match=f"have overridden `{hook_name}` in"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
assert hook == getattr(dm, hook_name)
|
||||
|
@ -123,7 +124,6 @@ class TestDataHookSelector:
|
|||
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
warning_cache.clear()
|
||||
assert hook == getattr(dm, hook_name)
|
||||
|
||||
def test_with_datamodule_override_model(self, hook_name):
|
||||
|
@ -133,7 +133,6 @@ class TestDataHookSelector:
|
|||
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
|
||||
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
|
||||
|
||||
warning_cache.clear()
|
||||
assert hook == getattr(model, hook_name)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue