Improve argument validation for validate(), test(), and predict() (#7605)
Co-authored-by: Yifu Wang <yifuwang@2012@gmail.com>
This commit is contained in:
parent
e16d4fbdee
commit
8d6e2ff7b2
|
@ -502,6 +502,10 @@ class Trainer(
|
|||
|
||||
model_provided = model is not None
|
||||
model = model or self.lightning_module
|
||||
if model is None:
|
||||
raise MisconfigurationException(
|
||||
"`model` must be provided to `trainer.validate()` when it hasn't been passed in a previous run"
|
||||
)
|
||||
|
||||
# links data to the trainer
|
||||
self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule)
|
||||
|
@ -562,6 +566,10 @@ class Trainer(
|
|||
|
||||
model_provided = model is not None
|
||||
model = model or self.lightning_module
|
||||
if model is None:
|
||||
raise MisconfigurationException(
|
||||
"`model` must be provided to `trainer.test()` when it hasn't been passed in a previous run"
|
||||
)
|
||||
|
||||
# links data to the trainer
|
||||
self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule)
|
||||
|
@ -624,6 +632,10 @@ class Trainer(
|
|||
|
||||
model_provided = model is not None
|
||||
model = model or self.lightning_module
|
||||
if model is None:
|
||||
raise MisconfigurationException(
|
||||
"`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run"
|
||||
)
|
||||
|
||||
# links data to the trainer
|
||||
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
|
||||
|
|
|
@ -1927,3 +1927,14 @@ def test_module_current_fx_attributes_reset(tmpdir):
|
|||
trainer.test(model)
|
||||
assert model._current_fx_name is None
|
||||
assert model._current_dataloader_idx is None
|
||||
|
||||
|
||||
def test_exception_when_lightning_module_is_not_set_on_trainer():
|
||||
trainer = Trainer()
|
||||
|
||||
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*validate"):
|
||||
trainer.validate()
|
||||
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*test"):
|
||||
trainer.test()
|
||||
with pytest.raises(MisconfigurationException, match=r"`model` must be provided.*predict"):
|
||||
trainer.predict()
|
||||
|
|
Loading…
Reference in New Issue