parent
6b667b1237
commit
ceec51d96c
|
@ -301,8 +301,8 @@ class TrainerEvaluationLoopMixin(ABC):
|
||||||
|
|
||||||
def run_evaluation(self, test=False):
|
def run_evaluation(self, test=False):
|
||||||
# when testing make sure user defined a test step
|
# when testing make sure user defined a test step
|
||||||
if test and not (self.is_overriden('test_step') or self.is_overriden('test_end')):
|
if test and not self.is_overriden('test_step'):
|
||||||
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
|
m = '''You called `.test()` without defining model's `.test_step()`.
|
||||||
Please define and try again'''
|
Please define and try again'''
|
||||||
raise MisconfigurationException(m)
|
raise MisconfigurationException(m)
|
||||||
|
|
||||||
|
|
|
@ -793,22 +793,12 @@ def test_testpass_overrides(tmpdir):
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return self.train_dataloader()
|
return self.train_dataloader()
|
||||||
|
|
||||||
class TestModelNoStep(LightningTestModelBase):
|
|
||||||
def test_end(self, outputs):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def test_dataloader(self):
|
|
||||||
return self.train_dataloader()
|
|
||||||
|
|
||||||
# Misconfig when neither test_step or test_end is implemented
|
# Misconfig when neither test_step or test_end is implemented
|
||||||
with pytest.raises(MisconfigurationException):
|
with pytest.raises(MisconfigurationException):
|
||||||
model = LightningTestModelBase(hparams)
|
model = LightningTestModelBase(hparams)
|
||||||
Trainer().test(model)
|
Trainer().test(model)
|
||||||
|
|
||||||
# No exceptions when one or both of test_step or test_end are implemented
|
# No exceptions when one or both of test_step or test_end are implemented
|
||||||
model = TestModelNoStep(hparams)
|
|
||||||
Trainer().test(model)
|
|
||||||
|
|
||||||
model = TestModelNoEnd(hparams)
|
model = TestModelNoEnd(hparams)
|
||||||
Trainer().test(model)
|
Trainer().test(model)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue