deprecate `on_{train/val/test/predict}_dataloader()` from DataHooks (#9098)
Co-authored-by: Sean Naren <sean@grid.ai> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
c993d0ce33
commit
1657588f35
|
@ -161,6 +161,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
|
||||
|
||||
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
|
||||
|
||||
-
|
||||
|
||||
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from torch.optim.optimizer import Optimizer
|
|||
|
||||
from pytorch_lightning.utilities import move_data_to_device
|
||||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
|
||||
class ModelHooks:
|
||||
|
@ -684,16 +685,52 @@ class DataHooks:
|
|||
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
def on_train_dataloader(self) -> None:
|
||||
"""Called before requesting the train dataloader."""
|
||||
"""Called before requesting the train dataloader.
|
||||
|
||||
.. deprecated:: v1.5
|
||||
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
|
||||
Please use :meth:`train_dataloader()` directly.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `train_dataloader()` directly."
|
||||
)
|
||||
|
||||
def on_val_dataloader(self) -> None:
|
||||
"""Called before requesting the val dataloader."""
|
||||
"""Called before requesting the val dataloader.
|
||||
|
||||
.. deprecated:: v1.5
|
||||
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
|
||||
Please use :meth:`val_dataloader()` directly.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `val_dataloader()` directly."
|
||||
)
|
||||
|
||||
def on_test_dataloader(self) -> None:
|
||||
"""Called before requesting the test dataloader."""
|
||||
"""Called before requesting the test dataloader.
|
||||
|
||||
.. deprecated:: v1.5
|
||||
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
|
||||
Please use :meth:`test_dataloader()` directly.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `test_dataloader()` directly."
|
||||
)
|
||||
|
||||
def on_predict_dataloader(self) -> None:
|
||||
"""Called before requesting the predict dataloader."""
|
||||
"""Called before requesting the predict dataloader.
|
||||
|
||||
.. deprecated:: v1.5
|
||||
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
|
||||
Please use :meth:`predict_dataloader()` directly.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `predict_dataloader()` directly."
|
||||
)
|
||||
|
||||
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
@ -75,6 +75,25 @@ class ConfigValidator:
|
|||
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
|
||||
)
|
||||
|
||||
# ----------------------------------------------
|
||||
# verify model does not have
|
||||
# - on_train_dataloader
|
||||
# - on_val_dataloader
|
||||
# ----------------------------------------------
|
||||
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
|
||||
if has_on_train_dataloader:
|
||||
rank_zero_deprecation(
|
||||
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `train_dataloader()` directly."
|
||||
)
|
||||
|
||||
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
|
||||
if has_on_val_dataloader:
|
||||
rank_zero_deprecation(
|
||||
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `val_dataloader()` directly."
|
||||
)
|
||||
|
||||
trainer = self.trainer
|
||||
|
||||
trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
|
||||
|
@ -102,10 +121,39 @@ class ConfigValidator:
|
|||
if has_step and not has_loader:
|
||||
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
|
||||
|
||||
# ----------------------------------------------
|
||||
# verify model does not have
|
||||
# - on_val_dataloader
|
||||
# - on_test_dataloader
|
||||
# ----------------------------------------------
|
||||
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
|
||||
if has_on_val_dataloader:
|
||||
rank_zero_deprecation(
|
||||
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `val_dataloader()` directly."
|
||||
)
|
||||
|
||||
has_on_test_dataloader = is_overridden("on_test_dataloader", model)
|
||||
if has_on_test_dataloader:
|
||||
rank_zero_deprecation(
|
||||
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `test_dataloader()` directly."
|
||||
)
|
||||
|
||||
def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
|
||||
has_predict_dataloader = is_overridden("predict_dataloader", model)
|
||||
if not has_predict_dataloader:
|
||||
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
|
||||
# ----------------------------------------------
|
||||
# verify model does not have
|
||||
# - on_predict_dataloader
|
||||
# ----------------------------------------------
|
||||
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
|
||||
if has_on_predict_dataloader:
|
||||
rank_zero_deprecation(
|
||||
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
" Please use `predict_dataloader()` directly."
|
||||
)
|
||||
|
||||
def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
|
||||
"""Raise Misconfiguration exception since these hooks are not supported in DP mode"""
|
||||
|
|
|
@ -91,6 +91,27 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
|
|||
_ = Trainer(prepare_data_per_node=False)
|
||||
|
||||
|
||||
def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
|
||||
|
||||
model = BoringModel()
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
model.on_train_dataloader()
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
model.on_val_dataloader()
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
model.on_test_dataloader()
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
model.on_predict_dataloader()
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
|
||||
def test_v1_7_0_test_tube_logger(_, tmpdir):
|
||||
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
|
||||
|
|
|
@ -1868,3 +1868,30 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
|
|||
trainer.test(model)
|
||||
with pytest.raises(Exception, match=r"Error during predict"), patch("pytorch_lightning.Trainer._on_exception"):
|
||||
trainer.predict(model, model.val_dataloader(), return_predictions=False)
|
||||
|
||||
|
||||
def test_overridden_on_dataloaders(tmpdir):
|
||||
model = BoringModel()
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
trainer.fit(model)
|
||||
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
trainer.validate(model)
|
||||
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
trainer.test(model)
|
||||
|
||||
with pytest.deprecated_call(
|
||||
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
|
||||
):
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
trainer.predict(model)
|
||||
|
|
Loading…
Reference in New Issue