add test for trainer.test() (#1858)
* fix trainer.test() * Update trainer.py Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
d7f9c03663
commit
1a797bdad5
|
@ -121,3 +121,14 @@ def test_wrong_test_settigs(tmpdir):
|
|||
model.test_dataloader = LightningModule.test_dataloader
|
||||
model.test_epoch_end = None
|
||||
trainer.test(model, test_dataloaders=model.dataloader(train=False))
|
||||
|
||||
# ----------------
|
||||
# if we are just testing, no need for train_dataloader, train_step, val_dataloader, and val_step
|
||||
# ----------------
|
||||
model = EvalModelTemplate(hparams)
|
||||
model.test_dataloader = LightningModule.test_dataloader
|
||||
model.train_dataloader = None
|
||||
model.train_step = None
|
||||
model.val_dataloader = None
|
||||
model.val_step = None
|
||||
trainer.test(model, test_dataloaders=model.dataloader(train=False))
|
||||
|
|
Loading…
Reference in New Issue