Fix for incorrect run on the validation set with overwritten validation_epoch_end and test_end (#1353)
* reorder if clauses * fix wrong method overload in test * fix formatting * update change_log * fix line too long
This commit is contained in:
parent
868b172f05
commit
ebd9fc9530
|
@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
|
||||
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
|
||||
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
|
||||
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).
|
||||
|
||||
## [0.7.1] - 2020-03-07
|
||||
|
||||
|
|
|
@ -295,20 +295,25 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
|
||||
model = model.module
|
||||
|
||||
# TODO: remove in v1.0.0
|
||||
if test_mode and self.is_overriden('test_end', model=model):
|
||||
eval_results = model.test_end(outputs)
|
||||
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
|
||||
' Use `test_epoch_end` instead.', DeprecationWarning)
|
||||
elif self.is_overriden('validation_end', model=model):
|
||||
eval_results = model.validation_end(outputs)
|
||||
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
|
||||
' Use `validation_epoch_end` instead.', DeprecationWarning)
|
||||
if test_mode:
|
||||
if self.is_overriden('test_end', model=model):
|
||||
# TODO: remove in v1.0.0
|
||||
eval_results = model.test_end(outputs)
|
||||
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
|
||||
' Use `test_epoch_end` instead.', DeprecationWarning)
|
||||
|
||||
if test_mode and self.is_overriden('test_epoch_end', model=model):
|
||||
eval_results = model.test_epoch_end(outputs)
|
||||
elif self.is_overriden('validation_epoch_end', model=model):
|
||||
eval_results = model.validation_epoch_end(outputs)
|
||||
elif self.is_overriden('test_epoch_end', model=model):
|
||||
eval_results = model.test_epoch_end(outputs)
|
||||
|
||||
else:
|
||||
if self.is_overriden('validation_end', model=model):
|
||||
# TODO: remove in v1.0.0
|
||||
eval_results = model.validation_end(outputs)
|
||||
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
|
||||
' Use `validation_epoch_end` instead.', DeprecationWarning)
|
||||
|
||||
elif self.is_overriden('validation_epoch_end', model=model):
|
||||
eval_results = model.validation_epoch_end(outputs)
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
|
|
|
@ -528,15 +528,15 @@ def test_disabled_validation():
|
|||
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
|
||||
|
||||
validation_step_invoked = False
|
||||
validation_end_invoked = False
|
||||
validation_epoch_end_invoked = False
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
self.validation_step_invoked = True
|
||||
return super().validation_step(*args, **kwargs)
|
||||
|
||||
def validation_end(self, *args, **kwargs):
|
||||
self.validation_end_invoked = True
|
||||
return super().validation_end(*args, **kwargs)
|
||||
def validation_epoch_end(self, *args, **kwargs):
|
||||
self.validation_epoch_end_invoked = True
|
||||
return super().validation_epoch_end(*args, **kwargs)
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentModel(hparams)
|
||||
|
@ -555,8 +555,10 @@ def test_disabled_validation():
|
|||
# check that val_percent_check=0 turns off validation
|
||||
assert result == 1, 'training failed to complete'
|
||||
assert trainer.current_epoch == 1
|
||||
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
|
||||
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
|
||||
assert not model.validation_step_invoked, \
|
||||
'`validation_step` should not run when `val_percent_check=0`'
|
||||
assert not model.validation_epoch_end_invoked, \
|
||||
'`validation_epoch_end` should not run when `val_percent_check=0`'
|
||||
|
||||
# check that val_percent_check has no influence when fast_dev_run is turned on
|
||||
model = CurrentModel(hparams)
|
||||
|
@ -566,8 +568,10 @@ def test_disabled_validation():
|
|||
|
||||
assert result == 1, 'training failed to complete'
|
||||
assert trainer.current_epoch == 0
|
||||
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
|
||||
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
|
||||
assert model.validation_step_invoked, \
|
||||
'did not run `validation_step` with `fast_dev_run=True`'
|
||||
assert model.validation_epoch_end_invoked, \
|
||||
'did not run `validation_epoch_end` with `fast_dev_run=True`'
|
||||
|
||||
|
||||
def test_nan_loss_detection(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue