Use `global_step` while restoring logging step for old checkpoints (#13645)

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Rohit Gupta 2022-07-20 00:23:22 +05:30 committed by GitHub
parent 6cbd9d7575
commit c67b075cf5
5 changed files with 43 additions and 4 deletions


@@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))

 ## [1.6.4] - 2022-06-01

 ### Added


@@ -287,7 +287,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
-        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
+        # restore global step instead to make sure logging works correctly if a checkpoint from <v1.6.5 is used to resume
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)

     def _run_validation(self) -> None:
         # reload dataloaders
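A minimal sketch of the fallback this hunk introduces, outside of Lightning (the `FakeEpochLoop` class below is a stand-in, not real Lightning code): when a checkpoint written before v1.6.5 has no `_batches_that_stepped` entry, the logging counter is now seeded from `global_step` instead of restarting at 0.

```python
class FakeEpochLoop:
    """Stand-in for TrainingEpochLoop, only to illustrate the new fallback."""

    def __init__(self, global_step: int) -> None:
        self.global_step = global_step
        self._batches_that_stepped = 0

    def on_load_checkpoint(self, state_dict: dict) -> None:
        # pre-v1.6.5 checkpoints have no "_batches_that_stepped" key;
        # falling back to global_step keeps the logging step continuous across resume
        self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)


loop = FakeEpochLoop(global_step=100)
loop.on_load_checkpoint({})  # simulate a checkpoint saved before v1.6.5
assert loop._batches_that_stepped == 100  # previously this would have reset to 0
```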


@@ -264,10 +264,19 @@ class CheckpointConnector:
             return

         fit_loop = self.trainer.fit_loop
+        pl_module = self.trainer.lightning_module
+        assert pl_module is not None

         # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
-        optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
-        optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
+        batch_loop = fit_loop.epoch_loop.batch_loop
+        if pl_module.automatic_optimization:
+            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                "global_step"
+            ]
+        else:
+            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

         # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
         fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
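A rough sketch of the branching the connector now performs, using plain dicts instead of Lightning's real progress-tracking objects (the function name `seed_step_progress` and the dict keys are illustrative assumptions; only `global_step` and `automatic_optimization` come from the hunk above):

```python
def seed_step_progress(checkpoint: dict, automatic_optimization: bool, batch_loop: dict) -> None:
    """Seed the optimizer-step counter from the checkpoint's `global_step`.

    `batch_loop` stands in for fit_loop.epoch_loop.batch_loop; the nested dicts
    stand in for the optimizer-loop / manual-loop progress objects.
    """
    global_step = checkpoint["global_step"]
    if automatic_optimization:
        batch_loop["optimizer_loop"]["optimizer_step_completed"] = global_step
    else:
        batch_loop["manual_loop"]["optim_step_completed"] = global_step


# before this change only the automatic-optimization path existed, so checkpoints
# trained with manual optimization resumed with an optimizer-step counter of 0
loop_state = {"optimizer_loop": {}, "manual_loop": {}}
seed_step_progress({"global_step": 7}, automatic_optimization=False, batch_loop=loop_state)
assert loop_state["manual_loop"]["optim_step_completed"] == 7
```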


@@ -28,7 +28,7 @@ import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -255,6 +255,7 @@ def test_correct_step_and_epoch(tmpdir):
         def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             assert self.trainer.global_step == first_max_epochs * train_batches
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches

     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
@@ -262,6 +263,29 @@
     assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


+@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
+def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
+    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
+    model = model_class()
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    ckpt = torch.load(ckpt_path)
+    # the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
+    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
+    torch.save(ckpt, ckpt_path)
+
+    class TestModel(model_class):
+        def on_train_start(self) -> None:
+            assert self.trainer.global_step == 1
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
+
+    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
+    model = TestModel()
+    trainer.fit(model, ckpt_path=ckpt_path)
+    new_loop = trainer.fit_loop.epoch_loop
+    assert new_loop.global_step == new_loop._batches_that_stepped == 2
+
+
 def test_fit_twice(tmpdir):
     epochs = []
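For anyone who wants to check whether one of their own checkpoints predates the new key, a small inspection snippet can reuse the key path exercised by the test above (the checkpoint path below is hypothetical):

```python
import torch

ckpt = torch.load("my_checkpoint.ckpt", map_location="cpu")  # hypothetical path
epoch_loop_state = ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]

if "_batches_that_stepped" in epoch_loop_state:
    print("logging step saved in checkpoint:", epoch_loop_state["_batches_that_stepped"])
else:
    # pre-v1.6.5 checkpoint: on resume, the logging step is seeded from global_step
    print("no logging step saved, will fall back to global_step =", ckpt["global_step"])
```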


@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
     ckpt_path = str(tmpdir / "last.ckpt")

     trainer = Trainer(**trainer_args)
+    trainer.strategy.connect(model)
+
     for fn in TrainerFn:
         if fn != TrainerFn.TUNING:
             trainer_fn = getattr(trainer, f"{fn}_loop")