From c993d0ce33d54ec51c5525476c95a19fe988c94b Mon Sep 17 00:00:00 2001
From: "B. Kerim Tshimanga"
Date: Sat, 28 Aug 2021 09:07:47 -0700
Subject: [PATCH] Make unimplemented dataloader hooks raise `NotImplementedError` (#9161)

---
 CHANGELOG.md                    |  2 ++
 pytorch_lightning/core/hooks.py |  7 +++++--
 tests/core/test_datamodules.py  | 15 ++++++++++-----
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7f90520f8..e507980184 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -143,6 +143,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
+- Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 7ff2188534..24c15462eb 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
+from pytorch_lightning.utilities import move_data_to_device
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
 
@@ -540,7 +540,7 @@ class DataHooks:
                 return {'mnist': mnist_loader, 'cifar': cifar_loader}
 
         """
-        rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
+        raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer")
 
     def test_dataloader(self) -> EVAL_DATALOADERS:
         r"""
@@ -602,6 +602,7 @@ class DataHooks:
             In the case where you return multiple test dataloaders, the :meth:`test_step`
             will have an argument ``dataloader_idx`` which matches the order here.
         """
+        raise NotImplementedError("`test_dataloader` must be implemented to be used with the Lightning Trainer")
 
     def val_dataloader(self) -> EVAL_DATALOADERS:
         r"""
@@ -654,6 +655,7 @@ class DataHooks:
             In the case where you return multiple validation dataloaders, the :meth:`validation_step`
             will have an argument ``dataloader_idx`` which matches the order here.
         """
+        raise NotImplementedError("`val_dataloader` must be implemented to be used with the Lightning Trainer")
 
     def predict_dataloader(self) -> EVAL_DATALOADERS:
         r"""
@@ -679,6 +681,7 @@ class DataHooks:
             In the case where you return multiple prediction dataloaders, the :meth:`predict`
             will have an argument ``dataloader_idx`` which matches the order here.
         """
+        raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
 
     def on_train_dataloader(self) -> None:
        """Called before requesting the train dataloader."""
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 3bfe3aaa6c..fe51937e5a 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -480,8 +480,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
     with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
         dm.train_dataloader()
         dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
-    assert dm.val_dataloader() is None
-    assert dm.test_dataloader() is None
+    with pytest.raises(NotImplementedError):
+        _ = dm.val_dataloader()
+    with pytest.raises(NotImplementedError):
+        _ = dm.test_dataloader()
 
     train_ds_sequence = [ds(), ds()]
     dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
@@ -493,8 +495,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
                 call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
             ]
         )
-    assert dm.val_dataloader() is None
-    assert dm.test_dataloader() is None
+    with pytest.raises(NotImplementedError):
+        _ = dm.val_dataloader()
+    with pytest.raises(NotImplementedError):
+        _ = dm.test_dataloader()
 
     valid_ds = ds()
     test_ds = ds()
@@ -504,7 +508,8 @@ def test_dm_init_from_datasets_dataloaders(iterable):
         dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
         dm.test_dataloader()
         dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
-    assert dm.train_dataloader() is None
+    with pytest.raises(NotImplementedError):
+        _ = dm.train_dataloader()
 
     valid_dss = [ds(), ds()]
     test_dss = [ds(), ds()]