diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4d42d1500..112b5170cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))
 
+- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
 
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 69033751eb..4f58889b42 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -88,6 +88,7 @@ class EvaluationLoop(DataLoaderLoop):
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
         void(*args, **kwargs)
 
+        # hook
         self._on_evaluation_model_eval()
         self.trainer.lightning_module.zero_grad()
@@ -199,7 +200,7 @@ class EvaluationLoop(DataLoaderLoop):
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
         # reset any `torchmetrics.Metric` and the logger connector state
-        self.trainer.logger_connector.reset(metrics=True)
+        self.trainer.logger_connector.reset_results(metrics=True)
 
     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py
index 212dc5c0e9..d4a6ab6d29 100644
--- a/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -78,7 +78,7 @@ class PredictionLoop(DataLoaderLoop):
         self.epoch_batch_indices = []
 
     def on_run_start(self) -> None:
-        """Calls ``on_predict_start`` hook."""
+        """Calls ``_on_predict_start`` hook."""
         self._on_predict_start()
 
     def advance(self, *args: Any, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 2e6b607784..ad6a84d64b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -286,14 +286,15 @@ class LoggerConnector:
         is_first_batch = bool(self._batch_idx) + self._split_idx == 0
         return is_different_fx and is_first_batch
 
-    def reset(self, metrics: Optional[bool] = None) -> None:
-        if self.trainer.sanity_checking:
-            # reset metrics
-            self._progress_bar_metrics = {}
-            self._logged_metrics = {}
-            self._callback_metrics = {}
-        assert self.trainer._results is not None
-        self.trainer._results.reset(metrics=metrics)
+    def reset_metrics(self) -> None:
+        self._progress_bar_metrics = {}
+        self._logged_metrics = {}
+        self._callback_metrics = {}
+
+    def reset_results(self, metrics: Optional[bool] = None) -> None:
+        if self.trainer._results is not None:
+            self.trainer._results.reset(metrics=metrics)
+
         self._batch_idx = None
         self._split_idx = None
         self._current_fx = None
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2e115decf3..3b64724e14 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1022,6 +1022,11 @@ class Trainer(
         # ----------------------------
         # TRAIN
         # ----------------------------
+
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_start")
@@ -1206,6 +1211,10 @@ class Trainer(
         stage = self.state.stage
         self.sanity_checking = True
 
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
         self.call_hook("on_sanity_check_start")
 
         # reload dataloaders
@@ -1217,8 +1226,9 @@ class Trainer(
 
         self.call_hook("on_sanity_check_end")
 
-        # reset validation metrics
-        self.logger_connector.reset()
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
 
         # reset the seed to what it was before sanity check
         # prevents sanity check to affect random sampling in training
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index e8b398bee8..7b94cfe970 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -536,6 +536,12 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
     # hp_metric + 2 steps + epoch + 2 steps + epoch
     expected_num_calls = 1 + 2 + 1 + 2 + 1
 
+    assert set(trainer.callback_metrics) == {
+        "train_loss",
+        "valid_loss_0_epoch",
+        "valid_loss_0",
+        "valid_loss_1",
+    }
     assert len(mock_log_metrics.mock_calls) == expected_num_calls
     assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
 
@@ -569,10 +575,6 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
 
     results = trainer.test(model)
     assert set(trainer.callback_metrics) == {
-        "train_loss",
-        "valid_loss_0_epoch",
-        "valid_loss_0",
-        "valid_loss_1",
         "test_loss",
     }
     assert set(results[0]) == {"test_loss"}
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5f1bdd1f34..7d565edb00 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1950,3 +1950,49 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
     ) as exception_hook:
         trainer.predict(model, model.val_dataloader(), return_predictions=False)
         exception_hook.assert_called()
+
+
+def test_trainer_metrics_reset_before_each_task(tmpdir):
+    """Test that callback, logged and progress bar metrics are reset before each task starts."""
+
+    class TestMetricRestartCallback(Callback):
+        def _make_assertions(self, trainer):
+            assert trainer.callback_metrics == {}
+            assert trainer.progress_bar_metrics == {}
+            assert trainer.logged_metrics == {}
+
+        def on_train_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_validation_start(self, trainer, *args, **kwargs):
+            if trainer.state.fn == TrainerFn.VALIDATING:
+                self._make_assertions(trainer)
+
+        def on_test_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_predict_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+
+        def training_step(self, *args, **kwargs):
+            self.log("train/metric", 7.0)
+            return super().training_step(*args, **kwargs)
+
+        def validation_step(self, *args, **kwargs):
+            self.log("val/metric", 14.0)
+            return super().validation_step(*args, **kwargs)
+
+        def test_step(self, *args, **kwargs):
+            self.log("test/metric", 21.0)
+            return super().test_step(*args, **kwargs)
+
+    model = CustomBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()])
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
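For context, a minimal standalone sketch of the behaviour this patch enforces. It is not part of the patch: it assumes a pytorch_lightning build that includes this change, and the `TinyModel` / `AssertMetricsCleared` names are invented for illustration. The point it demonstrates is that metrics logged during `fit()` are expected to be absent from `trainer.callback_metrics`, `trainer.logged_metrics` and `trainer.progress_bar_metrics` when `test()` starts.

# Illustrative sketch only -- not part of the patch. Assumes a pytorch_lightning
# version that includes this change; TinyModel / AssertMetricsCleared are made-up names.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Callback, LightningModule, Trainer


class RandomDataset(Dataset):
    """A tiny random dataset so the example is self-contained."""

    def __init__(self, size: int = 32, length: int = 16):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train/metric", loss)  # populates trainer.callback_metrics during fit()
        return loss

    def test_step(self, batch, batch_idx):
        self.log("test/metric", self.layer(batch).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)

    def test_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)


class AssertMetricsCleared(Callback):
    """Checks that metrics from the previous task were reset before the new task starts."""

    def on_test_start(self, trainer, pl_module):
        assert trainer.callback_metrics == {}
        assert trainer.logged_metrics == {}
        assert trainer.progress_bar_metrics == {}


if __name__ == "__main__":
    model = TinyModel()
    trainer = Trainer(fast_dev_run=1, callbacks=[AssertMetricsCleared()])
    trainer.fit(model)   # logs "train/metric"
    trainer.test(model)  # with this patch, metrics from fit() are cleared before on_test_start

Before this change, the same assertions would have failed: the old `LoggerConnector.reset` only cleared the metric dictionaries while `trainer.sanity_checking` was true, so metrics logged during one task could survive into the next, as the relocated assertion in `test_validation_step_log_with_tensorboard` shows.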