Reset the current progress tracking state during double evaluation (#11119)

Carlos Mocholí authored on 2021-12-17 19:20:11 +01:00 · committed by GitHub
parent 92d9fc2280
commit 75d96d9897
5 changed files with 18 additions and 4 deletions

CHANGELOG.md

@@ -302,6 +302,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
+- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))

 ## [1.5.6] - 2021-12-15

 ### Fixed
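
For context, a sketch of the failing usage this commit addresses. Fault tolerance in 1.5.x is opt-in via the `PL_FAULT_TOLERANT_TRAINING` environment variable; the `model` and `datamodule` objects below are assumed to exist and are hypothetical:

    import os

    from pytorch_lightning import Trainer

    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # enable fault tolerance (PL 1.5.x)

    trainer = Trainer()

    # First call runs normally and leaves the evaluation loop's progress
    # counters at their end-of-run values.
    trainer.validate(model, datamodule=datamodule)  # `model`/`datamodule` assumed

    # Before this fix, the loop state restored for the second call still
    # reported the previous run as finished, so the call was skipped entirely.
    trainer.validate(model, datamodule=datamodule)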

pytorch_lightning/loops/base.py

@@ -337,6 +337,4 @@ class Loop(ABC, Generic[T]):
                 v.reset(metrics=False)
         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        if _FaultTolerantMode.detect_current_mode().is_enabled:
-            self.restarting = True
+        self.restarting = True
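
This hunk stops gating `restarting` on fault-tolerant mode: any loop whose state was just reloaded is now always marked as restarting. A minimal toy sketch of the resulting pattern (not the real `Loop` class):

    class MiniLoop:
        """Toy stand-in for the `load_state_dict` pattern above; not the real API."""

        def __init__(self) -> None:
            self.restarting = False
            self._state: dict = {}

        def load_state_dict(self, state_dict: dict) -> None:
            self._state = dict(state_dict)
            # After the fix: reloading state always flags the loop as restarting,
            # whether or not fault-tolerant training is enabled.
            self.restarting = True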

pytorch_lightning/loops/dataloader/evaluation_loop.py

@@ -84,6 +84,10 @@ class EvaluationLoop(DataLoaderLoop):
             self._max_batches = [self._max_batches] * len(self.dataloaders)
         super().reset()
+        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
+        # need to reset the current state when the loop has finished running
+        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
+            self.dataloader_progress.reset_on_run()

     def on_skip(self) -> List:
         return []
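
The guard works because the progress dataclasses separate lifetime `total` counters from per-run `current` counters, and `reset_on_run` clears only the latter. A simplified sketch of that contract (the real classes in `pytorch_lightning/trainer/progress.py` track more states; the names `MiniTracker`/`MiniProgress` are illustrative):

    from dataclasses import dataclass, field


    @dataclass
    class MiniTracker:
        ready: int = 0
        completed: int = 0

        def reset(self) -> None:
            self.ready = 0
            self.completed = 0


    @dataclass
    class MiniProgress:
        total: MiniTracker = field(default_factory=MiniTracker)
        current: MiniTracker = field(default_factory=MiniTracker)

        def reset_on_run(self) -> None:
            # Fresh run: restart the per-run counters, keep the lifetime totals.
            self.current.reset()

        def reset_on_restart(self) -> None:
            # Mid-run restore: resume from whatever had completed.
            self.current.ready = self.current.completed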

pytorch_lightning/loops/dataloader/prediction_loop.py

@@ -70,10 +70,15 @@ class PredictionLoop(DataLoaderLoop):
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
-        super().reset()
         self.predictions = []
         self.epoch_batch_indices = []
+        super().reset()
+        # when restarting, if we are running twice, since there's no concept of `max_epochs` we need to reset the
+        # current state when the loop has finished running
+        if self.done:
+            self.dataloader_progress.reset_on_run()

     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls ``_on_predict_start`` hook."""
         self._on_predict_start()
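
`PredictionLoop` drops the `TrainerFn.FITTING` check because prediction never runs inside `fit`, so a finished loop always means a completed standalone run. A hedged usage sketch (`model` and `pred_dl` are assumed to exist):

    first = trainer.predict(model, dataloaders=pred_dl)   # runs as usual
    # Without the reset above, `dataloader_progress` still marked the loop as
    # done, and this second call would have returned without predicting.
    second = trainer.predict(model, dataloaders=pred_dl)
    assert len(first) == len(second)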

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

@@ -22,6 +22,7 @@ from deprecate import void
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _update_dataloader_iter
 from pytorch_lightning.trainer.progress import BatchProgress
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import (
     _collect_states_on_rank_zero_over_collection,
@@ -67,6 +68,10 @@ class EvaluationEpochLoop(Loop):
             self.batch_progress.reset_on_run()
         else:
             self.batch_progress.reset_on_restart()
+        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
+        # need to reset the current state when the loop has finished running
+        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
+            self.batch_progress.reset_on_run()

     def on_run_start(  # type: ignore[override]
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
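
The `TrainerFn.FITTING` exclusion matters because validation inside `fit` can be legitimately restored mid-epoch, in which case the counters must resume rather than restart. A toy mirror of the dispatch this hunk produces (`progress` is anything with the `MiniProgress` interface sketched earlier; all names here are illustrative, not the real API):

    def reset_batch_progress(progress, restarting: bool, done: bool, fitting: bool) -> None:
        """Illustrative mirror of `EvaluationEpochLoop.reset`'s progress handling."""
        if not restarting:
            progress.reset_on_run()
        else:
            # Resume an interrupted run where it left off.
            progress.reset_on_restart()
            # A standalone `validate`/`test` that already finished is not a
            # resume: clear the per-run counters so the loop runs again.
            if done and not fitting:
                progress.reset_on_run()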