parent
6b667b1237
commit
ceec51d96c
|
@ -301,8 +301,8 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
|
||||
def run_evaluation(self, test=False):
|
||||
# when testing make sure user defined a test step
|
||||
if test and not (self.is_overriden('test_step') or self.is_overriden('test_end')):
|
||||
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
|
||||
if test and not self.is_overriden('test_step'):
|
||||
m = '''You called `.test()` without defining model's `.test_step()`.
|
||||
Please define and try again'''
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
|
|
|
@ -793,22 +793,12 @@ def test_testpass_overrides(tmpdir):
|
|||
def test_dataloader(self):
|
||||
return self.train_dataloader()
|
||||
|
||||
class TestModelNoStep(LightningTestModelBase):
|
||||
def test_end(self, outputs):
|
||||
return {}
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.train_dataloader()
|
||||
|
||||
# Misconfig when neither test_step or test_end is implemented
|
||||
with pytest.raises(MisconfigurationException):
|
||||
model = LightningTestModelBase(hparams)
|
||||
Trainer().test(model)
|
||||
|
||||
# No exceptions when one or both of test_step or test_end are implemented
|
||||
model = TestModelNoStep(hparams)
|
||||
Trainer().test(model)
|
||||
|
||||
model = TestModelNoEnd(hparams)
|
||||
Trainer().test(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue