Improve argument validation for validate(), test(), and predict() (#7605)

Co-authored-by: Yifu Wang <yifuwang@2012@gmail.com>
This commit is contained in:
Yifu Wang 2021-05-21 09:03:16 -07:00 committed by GitHub
parent e16d4fbdee
commit 8d6e2ff7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View File

@ -502,6 +502,10 @@ class Trainer(
model_provided = model is not None
model = model or self.lightning_module
if model is None:
raise MisconfigurationException(
"`model` must be provided to `trainer.validate()` when it hasn't been passed in a previous run"
)
# links data to the trainer
self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule)
@ -562,6 +566,10 @@ class Trainer(
model_provided = model is not None
model = model or self.lightning_module
if model is None:
raise MisconfigurationException(
"`model` must be provided to `trainer.test()` when it hasn't been passed in a previous run"
)
# links data to the trainer
self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule)
@ -624,6 +632,10 @@ class Trainer(
model_provided = model is not None
model = model or self.lightning_module
if model is None:
raise MisconfigurationException(
"`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run"
)
# links data to the trainer
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

View File

@ -1927,3 +1927,14 @@ def test_module_current_fx_attributes_reset(tmpdir):
trainer.test(model)
assert model._current_fx_name is None
assert model._current_dataloader_idx is None
def test_exception_when_lightning_module_is_not_set_on_trainer():
trainer = Trainer()
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*validate"):
trainer.validate()
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*test"):
trainer.test()
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*predict"):
trainer.predict()