Move `training_output` validation to after `train_step_end` (#7868)
* move validation to after aggregation * changelog * add test for training_step_end * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3427cb728d
commit
f9fccdfb39
|
@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
|
||||
|
||||
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
|
||||
|
||||
- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
|
||||
|
|
|
@ -296,10 +296,10 @@ class TrainLoop:
|
|||
|
||||
self.trainer.logger_connector.cache_logged_metrics()
|
||||
|
||||
self._check_training_step_output(training_step_output)
|
||||
|
||||
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
|
||||
|
||||
self._check_training_step_output(training_step_output)
|
||||
|
||||
training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
|
||||
training_step_output, split_batch
|
||||
)
|
||||
|
|
|
@ -150,12 +150,12 @@ def test_should_stop_mid_epoch(tmpdir):
|
|||
@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
|
||||
def test_warning_invalid_trainstep_output(tmpdir, output):
|
||||
|
||||
class TestModel(BoringModel):
|
||||
class InvalidTrainStepModel(BoringModel):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
return output
|
||||
|
||||
model = TestModel()
|
||||
model = InvalidTrainStepModel()
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
with pytest.raises(
|
||||
|
@ -166,3 +166,22 @@ def test_warning_invalid_trainstep_output(tmpdir, output):
|
|||
)
|
||||
):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_warning_valid_train_step_end(tmpdir):
|
||||
|
||||
class ValidTrainStepEndModel(BoringModel):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self(batch)
|
||||
return {'output': output, 'batch': batch}
|
||||
|
||||
def training_step_end(self, outputs):
|
||||
loss = self.loss(outputs['batch'], outputs['output'])
|
||||
return loss
|
||||
|
||||
# No error is raised
|
||||
model = ValidTrainStepEndModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue