Restore log step during restart (#13467)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Authored by Rohit Gupta on 2022-07-12 15:15:59 +05:30; committed by GitHub
parent 24189c2e9f
commit df931e2486
4 changed files with 6 additions and 1 deletion


@@ -311,6 +311,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417))
+- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
+
 ## [1.6.4] - 2022-06-01


@@ -273,6 +273,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
+        state_dict["_batches_that_stepped"] = self._batches_that_stepped

         if (
             self.trainer is not None
@@ -292,6 +293,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

     def _run_validation(self) -> None:
         # reload dataloaders
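The two hooks above round-trip the counter through the loop's state dict, and `state_dict.get(..., 0)` keeps checkpoints written before this change loadable. A minimal standalone sketch of that pattern (the `LoopSketch` class is hypothetical, not part of Lightning):

```python
from typing import Dict


class LoopSketch:
    """Hypothetical stand-in for TrainingEpochLoop's checkpoint hooks."""

    def __init__(self) -> None:
        self._batches_that_stepped = 0

    def on_save_checkpoint(self) -> Dict:
        # write the step counter into the loop's state dict
        return {"_batches_that_stepped": self._batches_that_stepped}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # default to 0 so checkpoints that predate the key still load
        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)


# round-trip: the counter survives save/load instead of resetting
old = LoopSketch()
old._batches_that_stepped = 42
new = LoopSketch()
new.on_load_checkpoint(old.on_save_checkpoint())
assert new._batches_that_stepped == 42
```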


@@ -47,7 +47,7 @@ def test_loops_state_dict_structure():
     expected = {
         "fit_loop": {
             "state_dict": {},
-            "epoch_loop.state_dict": {},
+            "epoch_loop.state_dict": {"_batches_that_stepped": 0},
             "epoch_loop.batch_progress": {
                 "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
                 "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},


@@ -259,6 +259,7 @@ def test_correct_step_and_epoch(tmpdir):
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
     assert trainer.global_step == max_epochs * train_batches
+    assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


 def test_fit_twice(tmpdir):
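For context, a hedged end-to-end sketch of the behaviour the test above asserts: resume a run from a checkpoint and check that the step counter used for logging continues instead of resetting to zero. `DemoModel`, the hyperparameters, and the paths are illustrative, not taken from the PR:

```python
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer


class DemoModel(LightningModule):
    # minimal module: one linear layer, just enough to take optimizer steps
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("loss", loss)  # logged against the restored step counter
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(torch.randn(8, 4), batch_size=4)  # 2 batches/epoch


trainer = Trainer(max_epochs=1, default_root_dir="demo")
trainer.fit(DemoModel())
trainer.save_checkpoint("demo.ckpt")

# resume for one more epoch; with the fix the counter carries over
resumed = Trainer(max_epochs=2, default_root_dir="demo")
resumed.fit(DemoModel(), ckpt_path="demo.ckpt")
assert resumed.fit_loop.epoch_loop._batches_that_stepped == 4  # 2 + 2
```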