Fix for incorrect run on the validation set with overwritten validation_epoch_end and test_end (#1353)

* reorder if clauses

* fix wrong method overload in test

* fix formatting

* update change_log

* fix line too long
This commit is contained in:
Adrian Wälchli 2020-04-03 15:25:32 +02:00 committed by GitHub
parent 868b172f05
commit ebd9fc9530
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 21 deletions

View File

@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).
## [0.7.1] - 2020-03-07

View File

@ -295,20 +295,25 @@ class TrainerEvaluationLoopMixin(ABC):
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module
# TODO: remove in v1.0.0
if test_mode and self.is_overriden('test_end', model=model):
eval_results = model.test_end(outputs)
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)
elif self.is_overriden('validation_end', model=model):
eval_results = model.validation_end(outputs)
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)
if test_mode:
if self.is_overriden('test_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.test_end(outputs)
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)
if test_mode and self.is_overriden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)
elif self.is_overriden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)
elif self.is_overriden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)
else:
if self.is_overriden('validation_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.validation_end(outputs)
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)
elif self.is_overriden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)
# enable train mode again
model.train()

View File

@ -528,15 +528,15 @@ def test_disabled_validation():
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
validation_step_invoked = False
validation_end_invoked = False
validation_epoch_end_invoked = False
def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)
def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)
def validation_epoch_end(self, *args, **kwargs):
self.validation_epoch_end_invoked = True
return super().validation_epoch_end(*args, **kwargs)
hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)
@ -555,8 +555,10 @@ def test_disabled_validation():
# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
assert not model.validation_step_invoked, \
'`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_epoch_end_invoked, \
'`validation_epoch_end` should not run when `val_percent_check=0`'
# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
@ -566,8 +568,10 @@ def test_disabled_validation():
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
assert model.validation_step_invoked, \
'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_epoch_end_invoked, \
'did not run `validation_epoch_end` with `fast_dev_run=True`'
def test_nan_loss_detection(tmpdir):