From 8d6e2ff7b2b2db43797fb4cee52bb9bd6158e7f5 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Fri, 21 May 2021 09:03:16 -0700 Subject: [PATCH] Improve argument validation for validate(), test(), and predict() (#7605) Co-authored-by: Yifu Wang --- pytorch_lightning/trainer/trainer.py | 12 ++++++++++++ tests/trainer/test_trainer.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e9e5bfb185..6a20625978 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -502,6 +502,10 @@ class Trainer( model_provided = model is not None model = model or self.lightning_module + if model is None: + raise MisconfigurationException( + "`model` must be provided to `trainer.validate()` when it hasn't been passed in a previous run" + ) # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) @@ -562,6 +566,10 @@ class Trainer( model_provided = model is not None model = model or self.lightning_module + if model is None: + raise MisconfigurationException( + "`model` must be provided to `trainer.test()` when it hasn't been passed in a previous run" + ) # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) @@ -624,6 +632,10 @@ class Trainer( model_provided = model is not None model = model or self.lightning_module + if model is None: + raise MisconfigurationException( + "`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run" + ) # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c878cb9db1..2d0b68b1a6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1927,3 +1927,14 @@ def test_module_current_fx_attributes_reset(tmpdir): trainer.test(model) assert model._current_fx_name is None assert model._current_dataloader_idx is None + + +def test_exception_when_lightning_module_is_not_set_on_trainer(): + trainer = Trainer() + + with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*validate"): + trainer.validate() + with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*test"): + trainer.test() + with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*predict"): + trainer.predict()