diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1f7f5a82a9..442da2274c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -19,6 +19,7 @@ import torch from torch.optim.optimizer import Optimizer from pytorch_lightning.utilities import move_data_to_device +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS @@ -490,7 +491,7 @@ class DataHooks: # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} return {'mnist': mnist_loader, 'cifar': cifar_loader} """ - raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer") + raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer") def test_dataloader(self) -> EVAL_DATALOADERS: r""" @@ -544,7 +545,7 @@ class DataHooks: In the case where you return multiple test dataloaders, the :meth:`test_step` will have an argument ``dataloader_idx`` which matches the order here. """ - raise NotImplementedError("`test_dataloader` must be implemented to be used with the Lightning Trainer") + raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer") def val_dataloader(self) -> EVAL_DATALOADERS: r""" @@ -595,7 +596,7 @@ class DataHooks: In the case where you return multiple validation dataloaders, the :meth:`validation_step` will have an argument ``dataloader_idx`` which matches the order here. """ - raise NotImplementedError("`val_dataloader` must be implemented to be used with the Lightning Trainer") + raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer") def predict_dataloader(self) -> EVAL_DATALOADERS: r""" @@ -618,7 +619,9 @@ class DataHooks: In the case where you return multiple prediction dataloaders, the :meth:`predict_step` will have an argument ``dataloader_idx`` which matches the order here. """ - raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer") + raise MisconfigurationException( + "`predict_dataloader` must be implemented to be used with the Lightning Trainer" + ) def on_train_dataloader(self) -> None: """Called before requesting the train dataloader. diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index f015d96fef..4c337a44ca 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -377,9 +377,9 @@ def test_dm_init_from_datasets_dataloaders(iterable): with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock: dm.train_dataloader() dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True) - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): _ = dm.val_dataloader() - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): _ = dm.test_dataloader() train_ds_sequence = [ds(), ds()] @@ -392,9 +392,9 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True), ] ) - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): _ = dm.val_dataloader() - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): _ = dm.test_dataloader() valid_ds = ds() @@ -405,7 +405,7 @@ def test_dm_init_from_datasets_dataloaders(iterable): dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True) dm.test_dataloader() dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True) - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): _ = dm.train_dataloader() valid_dss = [ds(), ds()] diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 9429db07f6..c7fee3b0d5 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -18,7 +18,7 @@ import torch from torch import nn from torch.optim import Adam, SGD -from pytorch_lightning import Trainer +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 @@ -26,6 +26,11 @@ from tests.helpers import BoringModel from tests.helpers.runif import RunIf +def test_lightning_module_not_abstract(): + """Test that the LightningModule can be instantiated (it is not an abstract class).""" + _ = LightningModule() + + def test_property_current_epoch(): """Test that the current_epoch in LightningModule is accessible via the Trainer.""" model = BoringModel()