Move `training_output` validation to after `training_step_end` (#7868)

* move validation to after aggregation

* changelog

* add test for training_step_end

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Luis Perez 2021-06-08 01:37:50 -07:00 committed by GitHub
parent 3427cb728d
commit f9fccdfb39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 4 deletions

View File

@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed `_check_training_step_output` to be called after `training_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
- Fixed `apply_to_collection` so that it now works on custom collections ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))

View File

@ -296,10 +296,10 @@ class TrainLoop:
self.trainer.logger_connector.cache_logged_metrics()
self._check_training_step_output(training_step_output)
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
self._check_training_step_output(training_step_output)
training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
training_step_output, split_batch
)

View File

@ -150,12 +150,12 @@ def test_should_stop_mid_epoch(tmpdir):
@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
def test_warning_invalid_trainstep_output(tmpdir, output):
class TestModel(BoringModel):
class InvalidTrainStepModel(BoringModel):
def training_step(self, batch, batch_idx):
return output
model = TestModel()
model = InvalidTrainStepModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
with pytest.raises(
@ -166,3 +166,22 @@ def test_warning_invalid_trainstep_output(tmpdir, output):
)
):
trainer.fit(model)
def test_warning_valid_train_step_end(tmpdir):
    """Check that fitting succeeds when the loss comes from ``training_step_end``.

    ``training_step`` returns a dict holding only the model output and the
    batch; ``training_step_end`` turns that dict into the loss. The test
    passes when ``trainer.fit`` completes without raising.
    """

    class ValidTrainStepEndModel(BoringModel):

        def training_step(self, batch, batch_idx):
            # Forward the prediction together with its batch so that the
            # loss can be computed later, in `training_step_end`.
            prediction = self(batch)
            return {'output': prediction, 'batch': batch}

        def training_step_end(self, outputs):
            return self.loss(outputs['batch'], outputs['output'])

    # No error is raised
    model = ValidTrainStepEndModel()
    Trainer(default_root_dir=tmpdir, fast_dev_run=1).fit(model)