Reset the current progress tracking state during double evaluation (#11119)
This commit is contained in:
parent 92d9fc2280
commit 75d96d9897
@@ -302,6 +302,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
 
+- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
+
 ## [1.5.6] - 2021-12-15
 
 ### Fixed
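For context: the #11119 entry covers calling `trainer.validate()` or `trainer.test()` a second time on the same `Trainer` while fault-tolerant training is enabled; stale progress counters made the second call a no-op. A minimal sketch of that scenario against PL 1.5.x (the `TinyModel` here is illustrative, not from this repo):

```python
import os

# Fault-tolerant mode is controlled by this environment variable in PL 1.5.x.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        self.log("val_loss", loss)

    def val_dataloader(self):
        # 8 samples, batch_size=4 -> 2 validation batches per run
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

model = TinyModel()
trainer = Trainer(limit_val_batches=2, logger=False, enable_checkpointing=False)

trainer.validate(model)  # first call: runs both batches
trainer.validate(model)  # before #11119 this second call was skipped entirely
```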
@@ -337,6 +337,4 @@ class Loop(ABC, Generic[T]):
                 v.reset(metrics=False)
 
         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-
-        if _FaultTolerantMode.detect_current_mode().is_enabled:
-            self.restarting = True
+        self.restarting = True
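This hunk makes the restarting flag unconditional: any `load_state_dict` now marks the loop as restarting, not only fault-tolerant runs, which is what the `reset` overrides below key off. A toy sketch of that handshake (simplified classes, not the actual `Loop` API):

```python
class ToyLoop:
    """Toy illustration of the restarting handshake used by Lightning loops."""

    def __init__(self):
        self.restarting = False
        self.current = 0  # batches processed in the current run

    def load_state_dict(self, state_dict):
        self.current = state_dict["current"]
        # After #11119 this is unconditional: any reload marks the loop
        # as restarting, regardless of fault-tolerant mode.
        self.restarting = True

    def reset(self):
        if not self.restarting:
            self.current = 0  # fresh run: start counting from zero
        # on restart, keep `current` so the run resumes mid-way
        self.restarting = False

loop = ToyLoop()
loop.load_state_dict({"current": 3})
loop.reset()
assert loop.current == 3  # resumed, not wiped
```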
@@ -84,6 +84,10 @@ class EvaluationLoop(DataLoaderLoop):
             self._max_batches = [self._max_batches] * len(self.dataloaders)
 
         super().reset()
+        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
+        # need to reset the current state when the loop has finished running
+        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
+            self.dataloader_progress.reset_on_run()
 
     def on_skip(self) -> List:
         return []
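The new branch works because `reset_on_run` rewinds only the per-run (`current`) counters while the lifetime (`total`) counters keep accumulating; once a standalone `validate`/`test` run finishes, `done` is true and would otherwise stay true on the next call. A condensed sketch of those semantics (simplified from the progress-tracking dataclasses, field names abbreviated):

```python
from dataclasses import dataclass, field

@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0

@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)    # across the trainer's lifetime
    current: Tracker = field(default_factory=Tracker)  # within the ongoing run

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1

    def reset_on_run(self) -> None:
        # only the per-run counters go back to zero
        self.current.reset()

progress = Progress()
progress.increment_completed()  # the first `validate()` finishes a dataloader
progress.reset_on_run()         # what the new `if self.done ...` branch triggers
assert progress.current.completed == 0 and progress.total.completed == 1
```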
@@ -70,10 +70,15 @@ class PredictionLoop(DataLoaderLoop):
 
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
-        super().reset()
         self.predictions = []
         self.epoch_batch_indices = []
 
+        super().reset()
+        # when restarting, if we are running twice, since there's no concept of `max_epochs` we need to reset the
+        # current state when the loop has finished running
+        if self.done:
+            self.dataloader_progress.reset_on_run()
+
     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls ``_on_predict_start`` hook."""
         self._on_predict_start()
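Note that `PredictionLoop` needs no `TrainerFn` guard: `predict` never runs inside `fit`, so a finished loop can always be rewound. A hedged usage sketch of the behaviour this hunk enables (illustrative model, not from this repo):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyPredictModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def predict_step(self, batch, batch_idx):
        return self.layer(batch[0])

    def predict_dataloader(self):
        # 8 samples, batch_size=4 -> 2 prediction batches per run
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

model = TinyPredictModel()
trainer = Trainer(logger=False, enable_checkpointing=False)

first = trainer.predict(model)   # 2 batches of predictions
second = trainer.predict(model)  # the rerun now rewinds and produces them again
assert len(first) == len(second) == 2
```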
@@ -22,6 +22,7 @@ from deprecate import void
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _update_dataloader_iter
 from pytorch_lightning.trainer.progress import BatchProgress
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import (
     _collect_states_on_rank_zero_over_collection,
@@ -67,6 +68,10 @@ class EvaluationEpochLoop(Loop):
             self.batch_progress.reset_on_run()
         else:
             self.batch_progress.reset_on_restart()
+        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
+        # need to reset the current state when the loop has finished running
+        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
+            self.batch_progress.reset_on_run()
 
     def on_run_start(  # type: ignore[override]
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
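The `!= TrainerFn.FITTING` guard keeps the rewind away from validation that runs inside `fit`, where the counters must survive so a fault-tolerant restart can resume mid-epoch rather than replay it. For reference, a paraphrase of why a stale `done` skipped the whole second run (condensed from the base loop's control flow in PL 1.5.x, not verbatim):

```python
# Paraphrased control flow of Loop.run (condensed):
def run(loop, *args, **kwargs):
    if loop.skip:
        return loop.on_skip()
    loop.reset()                  # <- the fix rewinds finished counters here
    loop.on_run_start(*args, **kwargs)
    while not loop.done:          # with stale `current` counters, `done` was
        loop.advance(*args, **kwargs)  # already True, so this body never ran
    return loop.on_run_end()
```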