82 lines
2.5 KiB
Python
Executable File
82 lines
2.5 KiB
Python
Executable File
import pytest
|
|
|
|
import tests.base.develop_utils as tutils
|
|
from pytorch_lightning import Trainer, LightningModule
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
# TODO: assert on the expected error/warning messages (e.g. via the `match=` argument of `pytest.raises` / `pytest.warns`)
|
|
|
|
|
|
def test_wrong_train_setting(tmpdir):
    """
    * Test that an error is thrown when no `train_dataloader()` is defined
    * Test that an error is thrown when no `training_step()` is defined

    Args:
        tmpdir: directory handed to the ``Trainer`` as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
    """
    tutils.reset_seed()
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # The model is built and broken *outside* of `pytest.raises` so that only
    # the `fit` call can satisfy the expected exception; an error raised
    # during setup now fails the test instead of silently passing it.
    model = EvalModelTemplate(**hparams)
    model.train_dataloader = None
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)

    # A fresh model is used so the second check is independent of the first.
    model = EvalModelTemplate(**hparams)
    model.training_step = None
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)
|
|
|
|
|
|
def test_wrong_configure_optimizers(tmpdir):
    """ Test that an error is thrown when no `configure_optimizers()` is defined

    Args:
        tmpdir: directory handed to the ``Trainer`` as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
    """
    tutils.reset_seed()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # Prepare the model outside of `pytest.raises` so that only the `fit`
    # call itself is allowed to produce the expected exception.
    model = EvalModelTemplate()
    model.configure_optimizers = None
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)
|
|
|
|
|
|
def test_val_loop_config(tmpdir):
    """
    When either val loop or val data are missing raise warning

    Args:
        tmpdir: directory handed to the ``Trainer`` as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
    """
    tutils.reset_seed()
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # Models are prepared outside of `pytest.warns` so that only a warning
    # emitted by the `fit` call can satisfy the expectation.

    # has val data but no val loop
    model = EvalModelTemplate(**hparams)
    model.validation_step = None
    with pytest.warns(UserWarning):
        trainer.fit(model)

    # has val loop but no val data
    model = EvalModelTemplate(**hparams)
    model.val_dataloader = None
    with pytest.warns(UserWarning):
        trainer.fit(model)
|
|
|
|
|
|
def test_test_loop_config(tmpdir):
    """
    When either test loop or test data are missing

    Each case expects ``Trainer.test`` to emit a ``UserWarning``.

    Args:
        tmpdir: directory handed to the ``Trainer`` as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
    """
    # Seed reset added for consistency with the other tests in this module.
    tutils.reset_seed()
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # Models are prepared outside of `pytest.warns` so that only a warning
    # emitted by the `test` call can satisfy the expectation.

    # has test loop but no test data
    model = EvalModelTemplate(**hparams)
    model.test_dataloader = None
    with pytest.warns(UserWarning):
        trainer.test(model)

    # has test data but no test loop
    model = EvalModelTemplate(**hparams)
    model.test_step = None
    with pytest.warns(UserWarning):
        trainer.test(model, test_dataloaders=model.dataloader(train=False))
|