[Refactor] 1/2 Move reset_on_restart within the loop reset (#9561)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
thomas chaton 2021-09-17 17:11:32 +01:00 committed by GitHub
parent 8b9d977a08
commit 89ab2470c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 61 deletions

View File

@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

View File

@ -20,8 +20,7 @@ from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
T = TypeVar("T") # the output type of `run`
@ -200,25 +199,19 @@ class Loop(ABC, Generic[T]):
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
self._load_from_state_dict(state_dict.copy(), prefix, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
v.load_state_dict(state_dict.copy(), prefix + k + ".")
def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
elif (
isinstance(v, ResultCollection)
and self.trainer is not None

View File

@ -57,6 +57,8 @@ class DataLoaderLoop(Loop):
"""Resets the internal state."""
if not self.restarting:
self.dataloader_progress.current.reset()
else:
self.dataloader_progress.current.reset_on_restart()
def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()

View File

@ -59,6 +59,8 @@ class EvaluationEpochLoop(Loop):
if not self.restarting:
self.batch_progress.current.reset()
else:
self.batch_progress.current.reset_on_restart()
def on_run_start(
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int

View File

@ -93,6 +93,11 @@ class TrainingEpochLoop(loops.Loop):
"""Resets the internal state of the loop for a new run."""
assert self.batch_loop is not None
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.current.reset_on_restart()
self.scheduler_progress.current.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
self.is_last_batch = False
# track epoch output

View File

@ -174,6 +174,8 @@ class FitLoop(Loop):
def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:
self.epoch_progress.current.reset_on_restart()
def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""

View File

@ -197,6 +197,8 @@ class OptimizerLoop(Loop):
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
self.optim_progress.optimizer_position = 0
else:
self.optim_progress.reset_on_restart()
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
def on_run_start( # type: ignore[override]

View File

@ -153,6 +153,9 @@ class Progress(BaseProgress):
self.total.load_state_dict(state_dict["total"])
self.current.load_state_dict(state_dict["current"])
def reset_on_restart(self) -> None:
self.current.reset_on_restart()
@dataclass
class DataLoaderProgress(Progress):
@ -229,3 +232,7 @@ class OptimizationProgress(BaseProgress):
def load_state_dict(self, state_dict: dict) -> None:
self.optimizer.load_state_dict(state_dict["optimizer"])
self.optimizer_position = state_dict["optimizer_position"]
def reset_on_restart(self) -> None:
self.optimizer.step.current.reset_on_restart()
self.optimizer.zero_grad.current.reset_on_restart()

View File

@ -297,7 +297,7 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
}
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected
trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
trainer.fit_loop.load_state_dict(checkpoint)
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@ -319,18 +319,6 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
}
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
trainer.fit_loop.load_state_dict(checkpoint)
expected = {
"total": {
"ready": total_val_batch + 1,
"started": total_val_batch + 1,
"processed": total_val_batch,
"completed": total_val_batch,
},
"current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
}
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@ -496,7 +484,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
}
assert checkpoint["loops"]["fit_loop"] == expected
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
state_dict = trainer.fit_loop.state_dict()
# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@ -504,8 +492,14 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
assert state_dict == checkpoint["loops"]["fit_loop"]
# with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
# test resetting manually, we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.reset()
trainer.fit_loop.epoch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()
epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
@ -691,28 +685,26 @@ def test_fit_loop_reset(tmpdir):
assert not epoch_loop.restarting
assert not optimizer_loop.restarting
# we load exactly what was saved - no reset yet
fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
def mid_epoch_reset_assertions():
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2
# resetting from a mid-epoch checkpoint should not change progress counters
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_position == 1
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()
mid_epoch_reset_assertions()
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2
assert optimizer_loop.restarting
assert optimizer_loop.optim_progress.optimizer_position == 1
# reset state loaded from a checkpoint from the end of an epoch
@ -723,23 +715,9 @@ def test_fit_loop_reset(tmpdir):
epoch_loop.restarting = False
optimizer_loop.restarting = False
# we load exactly what was saved - no reset yet
fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4
assert epoch_loop.batch_progress.current.completed == 4
assert optimizer_loop.optim_progress.optimizer_position == 1
# resetting from a end-of-epoch checkpoint should reset the current counters to 0
# resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()