diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py index 603bc9147e..4d03035b46 100755 --- a/tests/trainer/test_checks.py +++ b/tests/trainer/test_checks.py @@ -121,3 +121,14 @@ def test_wrong_test_settigs(tmpdir): model.test_dataloader = LightningModule.test_dataloader model.test_epoch_end = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) + + # ---------------- + # if we are just testing, no need for train_dataloader, train_step, val_dataloader, and val_step + # ---------------- + model = EvalModelTemplate(hparams) + model.test_dataloader = LightningModule.test_dataloader + model.train_dataloader = None + model.train_step = None + model.val_dataloader = None + model.val_step = None + trainer.test(model, test_dataloaders=model.dataloader(train=False))