diff --git a/CHANGELOG.md b/CHANGELOG.md index c61a98bcf8..2613ce0364 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868)) + - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851)) - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685)) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 81498cbe3b..0e378bf86e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -296,10 +296,10 @@ class TrainLoop: self.trainer.logger_connector.cache_logged_metrics() - self._check_training_step_output(training_step_output) - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + self._check_training_step_output(training_step_output) + training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( training_step_output, split_batch ) diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index a2706e5d37..193399473d 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -150,12 +150,12 @@ def test_should_stop_mid_epoch(tmpdir): @pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )]) def test_warning_invalid_trainstep_output(tmpdir, output): - class TestModel(BoringModel): + class InvalidTrainStepModel(BoringModel): def training_step(self, batch, batch_idx): return output - model = TestModel() + model = InvalidTrainStepModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) with pytest.raises( @@ -166,3 +166,22 @@ def test_warning_invalid_trainstep_output(tmpdir, output): ) ): trainer.fit(model) + + +def test_warning_valid_train_step_end(tmpdir): + + class ValidTrainStepEndModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = self(batch) + return {'output': output, 'batch': batch} + + def training_step_end(self, outputs): + loss = self.loss(outputs['batch'], outputs['output']) + return loss + + # No error is raised + model = ValidTrainStepEndModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + + trainer.fit(model)