Make unimplemented dataloader hooks raise `NotImplementedError` (#9161)
This commit is contained in:
parent
3fd77cbde6
commit
c993d0ce33
|
@ -143,6 +143,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
|
||||
|
||||
|
||||
- Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
|
||||
|
|
|
@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional
|
|||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
|
||||
from pytorch_lightning.utilities import move_data_to_device
|
||||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
|
||||
|
||||
|
||||
|
@ -540,7 +540,7 @@ class DataHooks:
|
|||
return {'mnist': mnist_loader, 'cifar': cifar_loader}
|
||||
|
||||
"""
|
||||
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
def test_dataloader(self) -> EVAL_DATALOADERS:
|
||||
r"""
|
||||
|
@ -602,6 +602,7 @@ class DataHooks:
|
|||
In the case where you return multiple test dataloaders, the :meth:`test_step`
|
||||
will have an argument ``dataloader_idx`` which matches the order here.
|
||||
"""
|
||||
raise NotImplementedError("`test_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
def val_dataloader(self) -> EVAL_DATALOADERS:
|
||||
r"""
|
||||
|
@ -654,6 +655,7 @@ class DataHooks:
|
|||
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
|
||||
will have an argument ``dataloader_idx`` which matches the order here.
|
||||
"""
|
||||
raise NotImplementedError("`val_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
||||
r"""
|
||||
|
@ -679,6 +681,7 @@ class DataHooks:
|
|||
In the case where you return multiple prediction dataloaders, the :meth:`predict`
|
||||
will have an argument ``dataloader_idx`` which matches the order here.
|
||||
"""
|
||||
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
def on_train_dataloader(self) -> None:
|
||||
"""Called before requesting the train dataloader."""
|
||||
|
|
|
@ -480,8 +480,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
|||
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
|
||||
dm.train_dataloader()
|
||||
dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
|
||||
assert dm.val_dataloader() is None
|
||||
assert dm.test_dataloader() is None
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = dm.val_dataloader()
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = dm.test_dataloader()
|
||||
|
||||
train_ds_sequence = [ds(), ds()]
|
||||
dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
|
||||
|
@ -493,8 +495,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
|||
call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
|
||||
]
|
||||
)
|
||||
assert dm.val_dataloader() is None
|
||||
assert dm.test_dataloader() is None
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = dm.val_dataloader()
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = dm.test_dataloader()
|
||||
|
||||
valid_ds = ds()
|
||||
test_ds = ds()
|
||||
|
@ -504,7 +508,8 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
|||
dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
||||
dm.test_dataloader()
|
||||
dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
||||
assert dm.train_dataloader() is None
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = dm.train_dataloader()
|
||||
|
||||
valid_dss = [ds(), ds()]
|
||||
test_dss = [ds(), ds()]
|
||||
|
|
Loading…
Reference in New Issue