deprecate `on_{train/val/test/predict}_dataloader()` from DataHooks (#9098)

Co-authored-by: Sean Naren <sean@grid.ai>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
Ning 2021-08-28 10:27:56 -07:00 committed by GitHub
parent c993d0ce33
commit 1657588f35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 5 deletions

View File

@ -161,6 +161,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
-
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))

View File

@ -20,6 +20,7 @@ from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
class ModelHooks:
@ -684,16 +685,52 @@ class DataHooks:
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
def on_train_dataloader(self) -> None:
"""Called before requesting the train dataloader."""
"""Called before requesting the train dataloader.
.. deprecated:: v1.5
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`train_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)
def on_val_dataloader(self) -> None:
"""Called before requesting the val dataloader."""
"""Called before requesting the val dataloader.
.. deprecated:: v1.5
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`val_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)
def on_test_dataloader(self) -> None:
"""Called before requesting the test dataloader."""
"""Called before requesting the test dataloader.
.. deprecated:: v1.5
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`test_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)
def on_predict_dataloader(self) -> None:
"""Called before requesting the predict dataloader."""
"""Called before requesting the predict dataloader.
.. deprecated:: v1.5
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`predict_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
"""

View File

@ -13,7 +13,7 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@ -75,6 +75,25 @@ class ConfigValidator:
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
)
# ----------------------------------------------
# verify model does not have
# - on_train_dataloader
# - on_val_dataloader
# ----------------------------------------------
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
if has_on_train_dataloader:
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)
trainer = self.trainer
trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
@ -102,10 +121,39 @@ class ConfigValidator:
if has_step and not has_loader:
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
# ----------------------------------------------
# verify model does not have
# - on_val_dataloader
# - on_test_dataloader
# ----------------------------------------------
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)
has_on_test_dataloader = is_overridden("on_test_dataloader", model)
if has_on_test_dataloader:
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)
def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
has_predict_dataloader = is_overridden("predict_dataloader", model)
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
# verify model does not have
# - on_predict_dataloader
# ----------------------------------------------
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
if has_on_predict_dataloader:
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)
def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
"""Raise Misconfiguration exception since these hooks are not supported in DP mode"""

View File

@ -91,6 +91,27 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
_ = Trainer(prepare_data_per_node=False)
def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_train_dataloader()
with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_val_dataloader()
with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_test_dataloader()
with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_predict_dataloader()
@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
def test_v1_7_0_test_tube_logger(_, tmpdir):
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):

View File

@ -1868,3 +1868,30 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
trainer.test(model)
with pytest.raises(Exception, match=r"Error during predict"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.predict(model, model.val_dataloader(), return_predictions=False)
def test_overridden_on_dataloaders(tmpdir):
model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.validate(model)
with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.test(model)
with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.predict(model)