Restore log step during restart (#13467)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
24189c2e9f
commit
df931e2486
|
@ -311,6 +311,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417))
|
||||
|
||||
|
||||
- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
|
||||
|
||||
|
||||
## [1.6.4] - 2022-06-01
|
||||
|
||||
|
|
|
@ -273,6 +273,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
def on_save_checkpoint(self) -> Dict:
|
||||
state_dict = super().on_save_checkpoint()
|
||||
state_dict["_batches_that_stepped"] = self._batches_that_stepped
|
||||
|
||||
if (
|
||||
self.trainer is not None
|
||||
|
@ -292,6 +293,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
def on_load_checkpoint(self, state_dict: Dict) -> None:
|
||||
# cache the dataloader state dict until the dataloader objects are available
|
||||
self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
|
||||
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
|
||||
|
||||
def _run_validation(self) -> None:
|
||||
# reload dataloaders
|
||||
|
|
|
@ -47,7 +47,7 @@ def test_loops_state_dict_structure():
|
|||
expected = {
|
||||
"fit_loop": {
|
||||
"state_dict": {},
|
||||
"epoch_loop.state_dict": {},
|
||||
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
|
||||
"epoch_loop.batch_progress": {
|
||||
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
|
||||
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
|
||||
|
|
|
@ -259,6 +259,7 @@ def test_correct_step_and_epoch(tmpdir):
|
|||
trainer.fit(TestModel(), ckpt_path=ckpt_path)
|
||||
assert trainer.current_epoch == max_epochs
|
||||
assert trainer.global_step == max_epochs * train_batches
|
||||
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches
|
||||
|
||||
|
||||
def test_fit_twice(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue