add test for trainer.test() (#1858)

* fix trainer.test()

* Update trainer.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Victor Quach 2020-05-17 16:30:20 -04:00 committed by GitHub
parent d7f9c03663
commit 1a797bdad5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 0 deletions

View File

@ -121,3 +121,14 @@ def test_wrong_test_settigs(tmpdir):
model.test_dataloader = LightningModule.test_dataloader
model.test_epoch_end = None
trainer.test(model, test_dataloaders=model.dataloader(train=False))
# ----------------
# if we are just testing, no need for train_dataloader, train_step, val_dataloader, and val_step
# ----------------
model = EvalModelTemplate(hparams)
model.test_dataloader = LightningModule.test_dataloader
model.train_dataloader = None
model.train_step = None
model.val_dataloader = None
model.val_step = None
trainer.test(model, test_dataloaders=model.dataloader(train=False))