Reset metrics before each task starts (#9410)

* reset metrics
Rohit Gupta 2021-09-23 13:50:25 +05:30 committed by GitHub
parent bffc5347d2
commit 86dd318dcc
7 changed files with 79 additions and 16 deletions


@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))
- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410))

### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`


@@ -88,6 +88,7 @@ class EvaluationLoop(DataLoaderLoop):
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks."""
        void(*args, **kwargs)

        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
@@ -199,7 +200,7 @@ class EvaluationLoop(DataLoaderLoop):
        self.trainer.call_hook("on_validation_end", *args, **kwargs)

        # reset any `torchmetrics.Metric` and the logger connector state
        self.trainer.logger_connector.reset(metrics=True)
        self.trainer.logger_connector.reset_results(metrics=True)

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""


@@ -78,7 +78,7 @@ class PredictionLoop(DataLoaderLoop):
        self.epoch_batch_indices = []

    def on_run_start(self) -> None:
        """Calls ``on_predict_start`` hook."""
        """Calls ``_on_predict_start`` hook."""
        self._on_predict_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:


@@ -286,14 +286,15 @@ class LoggerConnector:
        is_first_batch = bool(self._batch_idx) + self._split_idx == 0
        return is_different_fx and is_first_batch

    def reset(self, metrics: Optional[bool] = None) -> None:
        if self.trainer.sanity_checking:
            # reset metrics
            self._progress_bar_metrics = {}
            self._logged_metrics = {}
            self._callback_metrics = {}
        assert self.trainer._results is not None
        self.trainer._results.reset(metrics=metrics)

    def reset_metrics(self) -> None:
        self._progress_bar_metrics = {}
        self._logged_metrics = {}
        self._callback_metrics = {}

    def reset_results(self, metrics: Optional[bool] = None) -> None:
        if self.trainer._results is not None:
            self.trainer._results.reset(metrics=metrics)

        self._batch_idx = None
        self._split_idx = None
        self._current_fx = None
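
For orientation, here is a minimal, self-contained sketch (illustrative names only, not the actual `LoggerConnector`) of the split introduced above: the old `reset()` mixed two jobs, clearing the aggregated metric dictionaries and resetting the result storage, while the two new methods let each call site clear exactly the state it needs to.

from typing import Dict, List


class SketchLoggerConnector:
    def __init__(self) -> None:
        self.progress_bar_metrics: Dict[str, float] = {}
        self.logged_metrics: Dict[str, float] = {}
        self.callback_metrics: Dict[str, float] = {}
        self._results: List[float] = []  # stands in for the ResultCollection

    def reset_metrics(self) -> None:
        """Clear only the aggregated metric dictionaries."""
        self.progress_bar_metrics = {}
        self.logged_metrics = {}
        self.callback_metrics = {}

    def reset_results(self) -> None:
        """Clear only the per-step result storage."""
        self._results.clear()


connector = SketchLoggerConnector()
connector.callback_metrics["val_loss"] = 0.25
connector._results.append(0.25)

connector.reset_results()            # e.g. at the end of validation
assert connector.callback_metrics    # aggregated metrics survive for callbacks
connector.reset_metrics()            # e.g. before the next task starts
assert not connector.callback_metrics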


@@ -1022,6 +1022,11 @@ class Trainer(
        # ----------------------------
        # TRAIN
        # ----------------------------

        # reset logger connector
        self.logger_connector.reset_results()
        self.logger_connector.reset_metrics()

        # hook
        if self.state.fn == TrainerFn.FITTING:
            self.call_hook("on_fit_start")
@@ -1206,6 +1211,10 @@ class Trainer(
            stage = self.state.stage
            self.sanity_checking = True

            # reset logger connector
            self.logger_connector.reset_results()
            self.logger_connector.reset_metrics()

            self.call_hook("on_sanity_check_start")

            # reload dataloaders
@@ -1217,8 +1226,9 @@ class Trainer(
            self.call_hook("on_sanity_check_end")

            # reset validation metrics
            self.logger_connector.reset()

            # reset logger connector
            self.logger_connector.reset_results()
            self.logger_connector.reset_metrics()

            # reset the seed to what it was before sanity check
            # prevents sanity check to affect random sampling in training
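
The two Trainer hunks above establish the ordering: reset both results and metrics before a task runs, again before and after the sanity check, and restore the RNG state afterwards so the check cannot influence random sampling during training. A rough sketch of that sequence, using Python's `random` module and a stub connector purely for illustration (the real Trainer uses its own seed-reset utility):

import random


class _StubConnector:
    """Stands in for the logger connector; only the two reset calls matter here."""

    def __init__(self):
        self.callback_metrics = {}
        self.results = []

    def reset_metrics(self):
        self.callback_metrics.clear()

    def reset_results(self):
        self.results.clear()


def run_sanity_check(connector, run_val_batches):
    state = random.getstate()      # remember the RNG state
    connector.reset_results()
    connector.reset_metrics()      # start the check from a clean slate
    run_val_batches()              # would run a couple of validation batches
    connector.reset_results()
    connector.reset_metrics()      # drop everything the check produced
    random.setstate(state)         # random sampling in training is unaffected


connector = _StubConnector()
run_sanity_check(connector, lambda: connector.callback_metrics.update({"val_loss": 0.5}))
assert connector.callback_metrics == {}  # nothing from the sanity check leaks into training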


@@ -536,6 +536,12 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
    # hp_metric + 2 steps + epoch + 2 steps + epoch
    expected_num_calls = 1 + 2 + 1 + 2 + 1

    assert set(trainer.callback_metrics) == {
        "train_loss",
        "valid_loss_0_epoch",
        "valid_loss_0",
        "valid_loss_1",
    }

    assert len(mock_log_metrics.mock_calls) == expected_num_calls
    assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
@@ -569,10 +575,6 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
    results = trainer.test(model)

    assert set(trainer.callback_metrics) == {
        "train_loss",
        "valid_loss_0_epoch",
        "valid_loss_0",
        "valid_loss_1",
        "test_loss",
    }
    assert set(results[0]) == {"test_loss"}


@@ -1950,3 +1950,49 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
    ) as exception_hook:
        trainer.predict(model, model.val_dataloader(), return_predictions=False)
    exception_hook.assert_called()


def test_trainer_metrics_reset_before_each_task(tmpdir):
    """Test that callback, logged and progress bar metrics are reset before each task starts."""

    class TestMetricRestartCallback(Callback):
        def _make_assertions(self, trainer):
            assert trainer.callback_metrics == {}
            assert trainer.progress_bar_metrics == {}
            assert trainer.logged_metrics == {}

        def on_train_start(self, trainer, *args, **kwargs):
            self._make_assertions(trainer)

        def on_validation_start(self, trainer, *args, **kwargs):
            if trainer.state.fn == TrainerFn.VALIDATING:
                self._make_assertions(trainer)

        def on_test_start(self, trainer, *args, **kwargs):
            self._make_assertions(trainer)

        def on_predict_start(self, trainer, *args, **kwargs):
            self._make_assertions(trainer)

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()

        def training_step(self, *args, **kwargs):
            self.log("train/metric", 7.0)
            return super().training_step(*args, **kwargs)

        def validation_step(self, *args, **kwargs):
            self.log("val/metric", 14.0)
            return super().validation_step(*args, **kwargs)

        def test_step(self, *args, **kwargs):
            self.log("test/metric", 21.0)
            return super().test_step(*args, **kwargs)

    model = CustomBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()])

    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)