[Refactor] 1/2 Move reset_on_restart within the loop reset (#9561)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
commit 89ab2470c1
parent 8b9d977a08
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
 * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
 * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
+* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))


 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
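For readers skimming the diff, the shape of the change is easier to see outside the framework. Below is a minimal standalone sketch of the pattern this PR moves toward, using toy classes rather than PyTorch Lightning's actual `Progress`, `Tracker`, or `Loop` implementations; in particular, the assumption that `reset_on_restart` rolls the in-flight counters back to `completed` is inferred from the surrounding hunks, not copied from the library.

from dataclasses import dataclass, field


@dataclass
class ToyTracker:
    """Counts how far a unit of work got: ready -> started -> processed -> completed."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset(self) -> None:
        # fresh run: everything back to zero
        self.ready = self.started = self.processed = self.completed = 0

    def reset_on_restart(self) -> None:
        # restart: anything that was in flight when the checkpoint was taken
        # is rolled back to the last fully completed unit
        self.ready = self.started = self.processed = self.completed


@dataclass
class ToyProgress:
    total: ToyTracker = field(default_factory=ToyTracker)
    current: ToyTracker = field(default_factory=ToyTracker)


class ToyLoop:
    def __init__(self) -> None:
        self.restarting = False
        self.batch_progress = ToyProgress()

    def load_state_dict(self, state: dict) -> None:
        # loading no longer touches the counters; it only restores the raw state
        # and flags that the next reset() happens in "restart" mode
        self.batch_progress.total = ToyTracker(**state["total"])
        self.batch_progress.current = ToyTracker(**state["current"])
        self.restarting = True

    def reset(self) -> None:
        # the branch this PR adds to every loop's reset() hook
        if not self.restarting:
            self.batch_progress.current.reset()
        else:
            self.batch_progress.current.reset_on_restart()


loop = ToyLoop()
loop.load_state_dict(
    {
        "total": {"ready": 4, "started": 4, "processed": 3, "completed": 3},
        "current": {"ready": 4, "started": 4, "processed": 3, "completed": 3},
    }
)
loop.reset()
assert loop.batch_progress.current.ready == 3  # rolled back to the last completed batch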
@@ -20,8 +20,7 @@ from torchmetrics import Metric

 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import BaseProgress, Progress
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`
@@ -200,25 +199,19 @@ class Loop(ABC, Generic[T]):
         self,
         state_dict: Dict,
         prefix: str = "",
-        restart_progress: bool = True,
         metrics: Optional[Dict[str, Metric]] = None,
     ) -> None:
         """Loads the state of this loop and all its children."""
-        self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
+        self._load_from_state_dict(state_dict.copy(), prefix, metrics)
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
-                v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
+                v.load_state_dict(state_dict.copy(), prefix + k + ".")

-    def _load_from_state_dict(
-        self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
-    ) -> None:
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
-                if restart_progress:
-                    apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
-
             elif (
                 isinstance(v, ResultCollection)
                 and self.trainer is not None
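The hunk above removes the `restart_progress` flag from the recursive state loading: loading now only restores raw counters and marks the loop as restarting, and each loop's `reset()` decides later how to realign its `current` progress. A toy sketch of that load-then-reset split follows; the class and attribute names (`ToyLoop`, `ToyProgress`) are hypothetical, and setting `restarting = True` at the end of loading is an assumption about how the restart flag gets raised, not a copy of the library's code.

from typing import Dict


class ToyProgress:
    def __init__(self) -> None:
        self.state: Dict = {}

    def load_state_dict(self, state: Dict) -> None:
        self.state = dict(state)


class ToyLoop:
    """Loads its own progress objects, then recurses into child loops with a key prefix."""

    def __init__(self, **children: "ToyLoop") -> None:
        self.progress = ToyProgress()
        self.restarting = False
        self.__dict__.update(children)

    def load_state_dict(self, state_dict: Dict, prefix: str = "") -> None:
        # no restart_progress flag anymore: loading only restores raw counters,
        # each loop's reset() decides later how to realign them
        self._load_from_state_dict(dict(state_dict), prefix)
        for name, child in self.__dict__.items():
            if isinstance(child, ToyLoop):
                child.load_state_dict(dict(state_dict), prefix + name + ".")
        self.restarting = True

    def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
        for name, attr in self.__dict__.items():
            if isinstance(attr, ToyProgress):
                attr.load_state_dict(state_dict[prefix + name])


# flat, prefixed layout mirroring how nested loop state is addressed
checkpoint = {
    "progress": {"current": {"completed": 3}},
    "epoch_loop.progress": {"current": {"completed": 7}},
}
fit = ToyLoop(epoch_loop=ToyLoop())
fit.load_state_dict(checkpoint)
assert fit.epoch_loop.progress.state == {"current": {"completed": 7}}
assert fit.restarting  # counters are realigned only once reset() runs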
@@ -57,6 +57,8 @@ class DataLoaderLoop(Loop):
         """Resets the internal state."""
         if not self.restarting:
             self.dataloader_progress.current.reset()
+        else:
+            self.dataloader_progress.current.reset_on_restart()

     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         self.dataloader_progress.increment_ready()
@@ -59,6 +59,8 @@ class EvaluationEpochLoop(Loop):

         if not self.restarting:
             self.batch_progress.current.reset()
+        else:
+            self.batch_progress.current.reset_on_restart()

     def on_run_start(
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
@@ -93,6 +93,11 @@ class TrainingEpochLoop(loops.Loop):
         """Resets the internal state of the loop for a new run."""
         assert self.batch_loop is not None
         assert self.batch_loop.optimizer_loop is not None
+        if self.restarting:
+            self.batch_progress.current.reset_on_restart()
+            self.scheduler_progress.current.reset_on_restart()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
+
         self.is_last_batch = False

         # track epoch output
@@ -174,6 +174,8 @@ class FitLoop(Loop):

     def reset(self) -> None:
         """Resets the internal state of this loop."""
+        if self.restarting:
+            self.epoch_progress.current.reset_on_restart()

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
@@ -197,6 +197,8 @@ class OptimizerLoop(Loop):
         if not self.restarting:
             # when reset() is called from outside (manually), we reset the loop progress
             self.optim_progress.optimizer_position = 0
+        else:
+            self.optim_progress.reset_on_restart()
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

     def on_run_start(  # type: ignore[override]
@@ -153,6 +153,9 @@ class Progress(BaseProgress):
         self.total.load_state_dict(state_dict["total"])
         self.current.load_state_dict(state_dict["current"])

+    def reset_on_restart(self) -> None:
+        self.current.reset_on_restart()
+

 @dataclass
 class DataLoaderProgress(Progress):
@@ -229,3 +232,7 @@ class OptimizationProgress(BaseProgress):
     def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
         self.optimizer_position = state_dict["optimizer_position"]
+
+    def reset_on_restart(self) -> None:
+        self.optimizer.step.current.reset_on_restart()
+        self.optimizer.zero_grad.current.reset_on_restart()
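The two methods added above only delegate: `Progress.reset_on_restart` forwards to its `current` tracker, and `OptimizationProgress.reset_on_restart` fans out to the nested `step` and `zero_grad` progress objects. A hedged sketch of that fan-out with toy dataclasses (not the library's real `OptimizationProgress`):

from dataclasses import dataclass, field


@dataclass
class ToyTracker:
    ready: int = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # drop whatever was only "ready" when the run stopped
        self.ready = self.completed


@dataclass
class ToyProgress:
    current: ToyTracker = field(default_factory=ToyTracker)

    def reset_on_restart(self) -> None:
        self.current.reset_on_restart()


@dataclass
class ToyOptimizerProgress:
    """Composite progress: one child Progress per kind of optimizer call."""

    step: ToyProgress = field(default_factory=ToyProgress)
    zero_grad: ToyProgress = field(default_factory=ToyProgress)


@dataclass
class ToyOptimizationProgress:
    optimizer: ToyOptimizerProgress = field(default_factory=ToyOptimizerProgress)
    optimizer_position: int = 0

    def reset_on_restart(self) -> None:
        # fan the restart out to every nested tracker, mirroring the methods
        # added to `Progress` and `OptimizationProgress` in this diff
        self.optimizer.step.current.reset_on_restart()
        self.optimizer.zero_grad.current.reset_on_restart()


progress = ToyOptimizerProgress()
progress.step.current.ready = 5
progress.step.current.completed = 4
wrapper = ToyOptimizationProgress(optimizer=progress)
wrapper.reset_on_restart()
assert wrapper.optimizer.step.current.ready == 4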
@@ -297,7 +297,7 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
     }
     assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

-    trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint)

     # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
     nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@@ -319,18 +319,6 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_
     }
     assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
-
-    trainer.fit_loop.load_state_dict(checkpoint)
-    expected = {
-        "total": {
-            "ready": total_val_batch + 1,
-            "started": total_val_batch + 1,
-            "processed": total_val_batch,
-            "completed": total_val_batch,
-        },
-        "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
-    }
-    assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected


 @RunIf(min_torch="1.7.0")
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@@ -496,7 +484,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
     }
     assert checkpoint["loops"]["fit_loop"] == expected

-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
     state_dict = trainer.fit_loop.state_dict()

     # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@@ -504,8 +492,14 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
     checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
     assert state_dict == checkpoint["loops"]["fit_loop"]

-    # with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
+    # test resetting manually, we expect all `ready` counters to be reset to `completed`
+    trainer.fit_loop.reset()
+    trainer.fit_loop.epoch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()

     epoch_progress = trainer.fit_loop.epoch_progress
     assert epoch_progress.current.ready == stop_epoch
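The updated test loads the checkpointed loop state and then calls `reset()` on every loop by hand; the invariant it checks is that each `current.ready` (and `started`) counter collapses back to `completed`. A hypothetical helper expressing that invariant against a progress-style state dict might look like this (the helper is illustrative, not part of the test suite):

def assert_restart_invariant(state_dict: dict) -> None:
    """Hypothetical check: after load + reset, no `current` counter may run ahead of `completed`."""
    for key, value in state_dict.items():
        if isinstance(value, dict) and "current" in value:
            current = value["current"]
            assert current["ready"] == current["completed"], key
            if "started" in current:
                assert current["started"] == current["completed"], key


# usage sketch against a progress-style state dict
assert_restart_invariant(
    {"batch_progress": {"current": {"ready": 3, "started": 3, "processed": 3, "completed": 3}}}
)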
@@ -691,28 +685,26 @@ def test_fit_loop_reset(tmpdir):
    assert not epoch_loop.restarting
    assert not optimizer_loop.restarting

    # we load exactly what was saved - no reset yet
    fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])

    def mid_epoch_reset_assertions():
        assert fit_loop.restarting
        assert fit_loop.epoch_progress.total.ready == 1
        assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
        assert fit_loop.epoch_progress.current.ready == 0
        assert fit_loop.epoch_progress.current.completed == 0

        assert epoch_loop.restarting
        assert epoch_loop.batch_progress.total.ready == 2
        assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
        assert epoch_loop.batch_progress.current.ready == 2
        assert epoch_loop.batch_progress.current.completed == 2

    # resetting from a mid-epoch checkpoint should not change progress counters
    mid_epoch_reset_assertions()
    assert optimizer_loop.optim_progress.optimizer_position == 1
    # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
    fit_loop.reset()
    epoch_loop.reset()
    optimizer_loop.reset()
    mid_epoch_reset_assertions()

    assert fit_loop.restarting
    assert fit_loop.epoch_progress.total.ready == 1
    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
    assert fit_loop.epoch_progress.current.ready == 0
    assert fit_loop.epoch_progress.current.completed == 0

    assert epoch_loop.restarting
    assert epoch_loop.batch_progress.total.ready == 2
    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
    assert epoch_loop.batch_progress.current.ready == 2
    assert epoch_loop.batch_progress.current.completed == 2

    assert optimizer_loop.restarting
    assert optimizer_loop.optim_progress.optimizer_position == 1

    # reset state loaded from a checkpoint from the end of an epoch
@@ -723,23 +715,9 @@ def test_fit_loop_reset(tmpdir):
    epoch_loop.restarting = False
    optimizer_loop.restarting = False

    # we load exactly what was saved - no reset yet
    fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])

    assert fit_loop.restarting
    assert fit_loop.epoch_progress.total.ready == 1
    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
    assert fit_loop.epoch_progress.current.ready == 0
    assert fit_loop.epoch_progress.current.completed == 0

    assert epoch_loop.restarting
    assert epoch_loop.batch_progress.total.ready == 4
    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
    assert epoch_loop.batch_progress.current.ready == 4
    assert epoch_loop.batch_progress.current.completed == 4

    assert optimizer_loop.optim_progress.optimizer_position == 1

    # resetting from a end-of-epoch checkpoint should reset the current counters to 0
    # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0
    fit_loop.reset()
    epoch_loop.reset()
    optimizer_loop.reset()
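The two halves of this test pin down the restart semantics spelled out in the comments: a mid-epoch checkpoint keeps the `current` counters (they are only rolled back to the last completed step), while an end-of-epoch checkpoint starts the next epoch from zero. A small self-contained illustration of that distinction, with a toy tracker and a decision rule of my own for when the previous epoch counts as finished:

from dataclasses import dataclass


@dataclass
class ToyTracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = self.completed = 0

    def reset_on_restart(self) -> None:
        self.ready = self.completed


def reset_current_for_restart(current: ToyTracker, batches_per_epoch: int) -> None:
    # Toy decision rule (mine, for illustration only): if the checkpoint was taken
    # after the epoch finished, the next run starts a fresh epoch, so `current`
    # goes back to zero; mid-epoch, only the in-flight `ready` count is rolled back.
    if current.completed >= batches_per_epoch:
        current.reset()
    else:
        current.reset_on_restart()


# mid-epoch checkpoint: counters are preserved (rolled back to `completed`)
mid = ToyTracker(ready=2, completed=2)
reset_current_for_restart(mid, batches_per_epoch=4)
assert (mid.ready, mid.completed) == (2, 2)

# end-of-epoch checkpoint: counters restart from zero for the new epoch
end = ToyTracker(ready=4, completed=4)
reset_current_for_restart(end, batches_per_epoch=4)
assert (end.ready, end.completed) == (0, 0)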