diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3b4a38dd6..40d05e36b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
     * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
     * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
+    * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))


 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 5573b04952..1a19c753b0 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -20,8 +20,7 @@ from torchmetrics import Metric

 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import BaseProgress, Progress
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`
@@ -200,25 +199,19 @@ class Loop(ABC, Generic[T]):
         self,
         state_dict: Dict,
         prefix: str = "",
-        restart_progress: bool = True,
         metrics: Optional[Dict[str, Metric]] = None,
     ) -> None:
         """Loads the state of this loop and all its children."""
-        self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
+        self._load_from_state_dict(state_dict.copy(), prefix, metrics)
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
-                v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
+                v.load_state_dict(state_dict.copy(), prefix + k + ".")

-    def _load_from_state_dict(
-        self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
-    ) -> None:
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
-                if restart_progress:
-                    apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
-
             elif (
                 isinstance(v, ResultCollection)
                 and self.trainer is not None
diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py
index 6b5fecd07e..00a5ee32b9 100644
--- a/pytorch_lightning/loops/dataloader/dataloader_loop.py
+++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -57,6 +57,8 @@ class DataLoaderLoop(Loop):
         """Resets the internal state."""
         if not self.restarting:
             self.dataloader_progress.current.reset()
+        else:
+            self.dataloader_progress.current.reset_on_restart()

     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         self.dataloader_progress.increment_ready()
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index f30df960c1..3c9cf0d717 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -59,6 +59,8 @@ class EvaluationEpochLoop(Loop):

         if not self.restarting:
             self.batch_progress.current.reset()
+        else:
+            self.batch_progress.current.reset_on_restart()

     def on_run_start(
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 274074653d..3d7f36477c 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -93,6 +93,11 @@ class TrainingEpochLoop(loops.Loop):
         """Resets the internal state of the loop for a new run."""
         assert self.batch_loop is not None
         assert self.batch_loop.optimizer_loop is not None
+        if self.restarting:
+            self.batch_progress.current.reset_on_restart()
+            self.scheduler_progress.current.reset_on_restart()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
+
         self.is_last_batch = False

         # track epoch output
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 3e9917a551..9a4f7c510f 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -174,6 +174,8 @@ class FitLoop(Loop):

     def reset(self) -> None:
         """Resets the internal state of this loop."""
+        if self.restarting:
+            self.epoch_progress.current.reset_on_restart()

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index 484c3ba4bd..a5a7ca7467 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -197,6 +197,8 @@ class OptimizerLoop(Loop):
         if not self.restarting:
             # when reset() is called from outside (manually), we reset the loop progress
             self.optim_progress.optimizer_position = 0
+        else:
+            self.optim_progress.reset_on_restart()
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

     def on_run_start(  # type: ignore[override]
diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index 5b4f072305..0f07c61999 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -153,6 +153,9 @@ class Progress(BaseProgress):
         self.total.load_state_dict(state_dict["total"])
         self.current.load_state_dict(state_dict["current"])

+    def reset_on_restart(self) -> None:
+        self.current.reset_on_restart()
+

 @dataclass
 class DataLoaderProgress(Progress):
@@ -229,3 +232,7 @@ class OptimizationProgress(BaseProgress):
     def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
         self.optimizer_position = state_dict["optimizer_position"]
+
+    def reset_on_restart(self) -> None:
+        self.optimizer.step.current.reset_on_restart()
+        self.optimizer.zero_grad.current.reset_on_restart()
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 870c525561..47145a2f8f 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -297,7 +297,7 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
     }
     assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

-    trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint)

     # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
     nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@@ -319,18 +319,6 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
     }
     assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected

-    trainer.fit_loop.load_state_dict(checkpoint)
-    expected = {
-        "total": {
-            "ready": total_val_batch + 1,
-            "started": total_val_batch + 1,
-            "processed": total_val_batch,
-            "completed": total_val_batch,
-        },
-        "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
-    }
-    assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
-

 @RunIf(min_torch="1.7.0")
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@@ -496,7 +484,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
     }
     assert checkpoint["loops"]["fit_loop"] == expected

-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
     state_dict = trainer.fit_loop.state_dict()

     # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@@ -504,8 +492,14 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
     checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
     assert state_dict == checkpoint["loops"]["fit_loop"]

-    # with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
+    # test resetting manually; we expect all `ready` counters to be reset to `completed`
+    trainer.fit_loop.reset()
+    trainer.fit_loop.epoch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()

     epoch_progress = trainer.fit_loop.epoch_progress
     assert epoch_progress.current.ready == stop_epoch
@@ -691,28 +685,26 @@ def test_fit_loop_reset(tmpdir):
     assert not epoch_loop.restarting
     assert not optimizer_loop.restarting

+    # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
-
-    def mid_epoch_reset_assertions():
-        assert fit_loop.restarting
-        assert fit_loop.epoch_progress.total.ready == 1
-        assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-        assert fit_loop.epoch_progress.current.ready == 0
-        assert fit_loop.epoch_progress.current.completed == 0
-
-        assert epoch_loop.restarting
-        assert epoch_loop.batch_progress.total.ready == 2
-        assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-        assert epoch_loop.batch_progress.current.ready == 2
-        assert epoch_loop.batch_progress.current.completed == 2
-
-    # resetting from a mid-epoch checkpoint should not change progress counters
-    mid_epoch_reset_assertions()
-    assert optimizer_loop.optim_progress.optimizer_position == 1
+    # resetting from a mid-epoch checkpoint SHOULD NOT reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
-    mid_epoch_reset_assertions()
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 2
+    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 2
+    assert epoch_loop.batch_progress.current.completed == 2
+
+    assert optimizer_loop.restarting
     assert optimizer_loop.optim_progress.optimizer_position == 1

     # reset state loaded from a checkpoint from the end of an epoch
@@ -723,23 +715,9 @@ def test_fit_loop_reset(tmpdir):
     epoch_loop.restarting = False
     optimizer_loop.restarting = False

+    # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
-
-    assert fit_loop.restarting
-    assert fit_loop.epoch_progress.total.ready == 1
-    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
-    assert fit_loop.epoch_progress.current.ready == 0
-    assert fit_loop.epoch_progress.current.completed == 0
-
-    assert epoch_loop.restarting
-    assert epoch_loop.batch_progress.total.ready == 4
-    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
-    assert epoch_loop.batch_progress.current.ready == 4
-    assert epoch_loop.batch_progress.current.completed == 4
-
-    assert optimizer_loop.optim_progress.optimizer_position == 1
-
-    # resetting from a end-of-epoch checkpoint should reset the current counters to 0
+    # resetting from an end-of-epoch checkpoint SHOULD reset the current counters to 0
    fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
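For context when reviewing: below is a minimal, self-contained sketch of the pattern this diff moves to. `load_state_dict` now restores progress counters exactly as saved, and rolling `ready`/`started`/`processed` back to `completed` is deferred to the loop's `reset()` hook, which applies it only when `restarting` is set. The `Tracker`, `Progress`, and `ToyLoop` classes are simplified stand-ins written for illustration under those assumptions, not the actual `pytorch_lightning` implementations.

```python
# Simplified sketch of the deferred `reset_on_restart` pattern.
# Illustrative names only - not the real Lightning classes.
from dataclasses import dataclass, field


@dataclass
class Tracker:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # Roll partially advanced counters back to the last fully completed step.
        self.ready = self.completed
        self.started = self.completed
        self.processed = self.completed


@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)
    current: Tracker = field(default_factory=Tracker)

    def load_state_dict(self, state_dict: dict) -> None:
        # Load exactly what was saved - no reset happens here anymore.
        self.total = Tracker(**state_dict["total"])
        self.current = Tracker(**state_dict["current"])


class ToyLoop:
    def __init__(self) -> None:
        self.batch_progress = Progress()
        self.restarting = False

    def load_state_dict(self, state_dict: dict) -> None:
        self.batch_progress.load_state_dict(state_dict["batch_progress"])
        self.restarting = True

    def reset(self) -> None:
        if not self.restarting:
            # Manual reset: start the current counters from scratch.
            self.batch_progress.current = Tracker()
        else:
            # Restart: only now are the counters rolled back.
            self.batch_progress.current.reset_on_restart()


loop = ToyLoop()
saved = {"ready": 2, "started": 2, "processed": 1, "completed": 1}
loop.load_state_dict({"batch_progress": {"total": dict(saved), "current": dict(saved)}})
assert loop.batch_progress.current.ready == 2  # restored verbatim
loop.reset()
assert loop.batch_progress.current.ready == 1  # rolled back to `completed`
```

Deferring the roll-back to `reset()` keeps the loaded state a faithful copy of the checkpoint and lets each loop apply the roll-back at the point where it restarts, which is what the mid-epoch vs. end-of-epoch assertions in the updated `test_fit_loop_reset` exercise.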