commit 86dd318dcc (parent bffc5347d2)

@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))

+- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410))
+
 ### Deprecated

 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
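The hunks below implement the #9410 entry above: `LoggerConnector.reset()` is split into `reset_metrics()` (clears the cached progress-bar, logged, and callback metric dictionaries) and `reset_results()` (clears the `ResultCollection` and the per-batch bookkeeping), and the `Trainer` now calls both before each task as well as around the sanity check. A standalone sketch of the new split follows the `LoggerConnector` hunk below.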
@@ -88,6 +88,7 @@ class EvaluationLoop(DataLoaderLoop):
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
         void(*args, **kwargs)

         # hook
         self._on_evaluation_model_eval()
         self.trainer.lightning_module.zero_grad()
@@ -199,7 +200,7 @@ class EvaluationLoop(DataLoaderLoop):
             self.trainer.call_hook("on_validation_end", *args, **kwargs)

         # reset any `torchmetrics.Metric` and the logger connector state
-        self.trainer.logger_connector.reset(metrics=True)
+        self.trainer.logger_connector.reset_results(metrics=True)

     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
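Note on the hunk above: evaluation teardown now calls `reset_results(metrics=True)` instead of the old `reset(metrics=True)`, so the `torchmetrics.Metric` state and the connector's batch bookkeeping are cleared when evaluation ends, while the cached metric dictionaries (`trainer.callback_metrics` and friends) are left for the start of the next task to reset.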
@@ -78,7 +78,7 @@ class PredictionLoop(DataLoaderLoop):
         self.epoch_batch_indices = []

     def on_run_start(self) -> None:
-        """Calls ``on_predict_start`` hook."""
+        """Calls ``_on_predict_start`` hook."""
         self._on_predict_start()

     def advance(self, *args: Any, **kwargs: Any) -> None:
@@ -286,14 +286,15 @@ class LoggerConnector:
         is_first_batch = bool(self._batch_idx) + self._split_idx == 0
         return is_different_fx and is_first_batch

-    def reset(self, metrics: Optional[bool] = None) -> None:
-        if self.trainer.sanity_checking:
-            # reset metrics
-            self._progress_bar_metrics = {}
-            self._logged_metrics = {}
-            self._callback_metrics = {}
-        assert self.trainer._results is not None
-        self.trainer._results.reset(metrics=metrics)
+    def reset_metrics(self) -> None:
+        self._progress_bar_metrics = {}
+        self._logged_metrics = {}
+        self._callback_metrics = {}
+
+    def reset_results(self, metrics: Optional[bool] = None) -> None:
+        if self.trainer._results is not None:
+            self.trainer._results.reset(metrics=metrics)
+
         self._batch_idx = None
         self._split_idx = None
         self._current_fx = None
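A minimal standalone sketch of the split above, using stub classes of our own (`ResultsStub` and `ConnectorSketch` are illustrative names, not Lightning API): `reset_metrics()` now clears the metric dictionaries unconditionally, where the old `reset()` cleared them only while sanity checking, and `reset_results()` tolerates a missing `ResultCollection` instead of asserting on it.

```python
from typing import Optional


class ResultsStub:
    """Stands in for ``trainer._results`` (a ``ResultCollection``)."""

    def __init__(self) -> None:
        self.metrics_were_reset = False

    def reset(self, metrics: Optional[bool] = None) -> None:
        self.metrics_were_reset = bool(metrics)


class ConnectorSketch:
    """Mirrors the two new methods from the hunk above."""

    def __init__(self) -> None:
        self._results: Optional[ResultsStub] = ResultsStub()
        self._progress_bar_metrics = {"loss": 1.0}
        self._logged_metrics = {"loss": 1.0}
        self._callback_metrics = {"loss": 1.0}
        self._batch_idx: Optional[int] = 3
        self._split_idx: Optional[int] = 0
        self._current_fx: Optional[str] = "training_step"

    def reset_metrics(self) -> None:
        # unconditional, unlike the old ``reset``, which only cleared these
        # dicts while ``trainer.sanity_checking`` was True
        self._progress_bar_metrics = {}
        self._logged_metrics = {}
        self._callback_metrics = {}

    def reset_results(self, metrics: Optional[bool] = None) -> None:
        # tolerates a missing ResultCollection instead of asserting on it
        if self._results is not None:
            self._results.reset(metrics=metrics)
        self._batch_idx = None
        self._split_idx = None
        self._current_fx = None


connector = ConnectorSketch()
connector.reset_metrics()
connector.reset_results(metrics=True)
assert connector._callback_metrics == {}
assert connector._batch_idx is None
assert connector._results.metrics_were_reset
```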
@@ -1022,6 +1022,11 @@ class Trainer(
         # ----------------------------
         # TRAIN
         # ----------------------------
+
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_start")
@@ -1206,6 +1211,10 @@ class Trainer(
         stage = self.state.stage
         self.sanity_checking = True

+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
         self.call_hook("on_sanity_check_start")

         # reload dataloaders
@@ -1217,8 +1226,9 @@ class Trainer(

         self.call_hook("on_sanity_check_end")

-        # reset validation metrics
-        self.logger_connector.reset()
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()

         # reset the seed to what it was before sanity check
         # prevents sanity check to affect random sampling in training
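Taken together, the two hunks above bracket the sanity check with resets, so nothing logged during sanity checking leaks into `trainer.callback_metrics` once real training begins. A condensed sketch of the control flow (function and parameter names here are assumptions for illustration, not the actual `Trainer` internals):

```python
def run_sanity_check_sketch(connector, call_hook, run_eval_loop):
    # start the sanity check from a clean slate
    connector.reset_results()
    connector.reset_metrics()

    call_hook("on_sanity_check_start")
    run_eval_loop()  # logs metrics we do not want to keep
    call_hook("on_sanity_check_end")

    # drop everything the sanity check logged before training starts
    connector.reset_results()
    connector.reset_metrics()
```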
@@ -536,6 +536,12 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
     # hp_metric + 2 steps + epoch + 2 steps + epoch
     expected_num_calls = 1 + 2 + 1 + 2 + 1

+    assert set(trainer.callback_metrics) == {
+        "train_loss",
+        "valid_loss_0_epoch",
+        "valid_loss_0",
+        "valid_loss_1",
+    }
     assert len(mock_log_metrics.mock_calls) == expected_num_calls
     assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
@@ -569,10 +575,6 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):

     results = trainer.test(model)
     assert set(trainer.callback_metrics) == {
-        "train_loss",
-        "valid_loss_0_epoch",
-        "valid_loss_0",
-        "valid_loss_1",
         "test_loss",
     }
     assert set(results[0]) == {"test_loss"}
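These two test hunks pin down the new behavior: immediately after `trainer.fit`, the fit-time metrics are still present in `trainer.callback_metrics`, but `trainer.test` resets them when it starts, so only `test_loss` remains afterwards.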
@@ -1950,3 +1950,49 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
     ) as exception_hook:
         trainer.predict(model, model.val_dataloader(), return_predictions=False)
     exception_hook.assert_called()
+
+
+def test_trainer_metrics_reset_before_each_task(tmpdir):
+    """Test that callback, logged and progress bar metrics are reset before each task starts."""
+
+    class TestMetricRestartCallback(Callback):
+        def _make_assertions(self, trainer):
+            assert trainer.callback_metrics == {}
+            assert trainer.progress_bar_metrics == {}
+            assert trainer.logged_metrics == {}
+
+        def on_train_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_validation_start(self, trainer, *args, **kwargs):
+            if trainer.state.fn == TrainerFn.VALIDATING:
+                self._make_assertions(trainer)
+
+        def on_test_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_predict_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+
+        def training_step(self, *args, **kwargs):
+            self.log("train/metric", 7.0)
+            return super().training_step(*args, **kwargs)
+
+        def validation_step(self, *args, **kwargs):
+            self.log("val/metric", 14.0)
+            return super().validation_step(*args, **kwargs)
+
+        def test_step(self, *args, **kwargs):
+            self.log("test/metric", 21.0)
+            return super().test_step(*args, **kwargs)
+
+    model = CustomBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()])
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
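A note on the new test: `fast_dev_run=4` keeps it cheap while still exercising fit, validate, test, and predict, and the `on_validation_start` assertion is guarded with `TrainerFn.VALIDATING`, presumably because validation launched from inside `fit` still legitimately carries the training metrics.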