Use `global_step` while restoring logging step for old checkpoints (#13645)
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
commit c67b075cf5 (parent 6cbd9d7575)
@@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
 
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
 
 ## [1.6.4] - 2022-06-01
 
 ### Added
@@ -287,7 +287,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
-        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
+        # restore global step instead to make sure logging works correctly if a checkpoint saved before v1.6.5 is used to resume
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)
 
     def _run_validation(self) -> None:
         # reload dataloaders
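The behavioural change is the default passed to `state_dict.get`: recent checkpoints carry `_batches_that_stepped` explicitly, while checkpoints written before v1.6.5 do not, so the loop now falls back to the restored `global_step` instead of 0 and the logger's step axis stays continuous across the restart. A minimal sketch of that fallback (a plain function over a dict, not the actual loop class):

# illustration only: stand-in for the loop's on_load_checkpoint fallback logic
def restore_logging_step(loop_state: dict, global_step: int) -> int:
    # checkpoints >= v1.6.5 store the key; older ones fall back to global_step instead of 0
    return loop_state.get("_batches_that_stepped", global_step)

assert restore_logging_step({}, global_step=1200) == 1200  # pre-1.6.5 checkpoint: was 0 before this fix
assert restore_logging_step({"_batches_that_stepped": 1200}, global_step=1200) == 1200  # newer checkpoints unchanged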
@@ -264,10 +264,19 @@ class CheckpointConnector:
             return
 
         fit_loop = self.trainer.fit_loop
+        pl_module = self.trainer.lightning_module
+        assert pl_module is not None
+
         # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
-        optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
-        optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
+        batch_loop = fit_loop.epoch_loop.batch_loop
+        if pl_module.automatic_optimization:
+            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                "global_step"
+            ]
+        else:
+            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
 
         # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
         fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
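With manual optimization the optimizer loop never runs, so seeding only `optimizer_loop.optim_progress` would leave `global_step` at 0 when resuming an old checkpoint; the new branch writes the count into the manual loop's progress instead. A simplified illustration with stand-in progress objects (hypothetical names, not the real loop classes):

# illustration only: dataclass stand-ins for the progress counters touched by CheckpointConnector
from dataclasses import dataclass, field

@dataclass
class _Progress:
    completed: int = 0

@dataclass
class _BatchLoopStub:
    automatic_step: _Progress = field(default_factory=_Progress)  # ~ optimizer_loop.optim_progress.optimizer.step.total
    manual_step: _Progress = field(default_factory=_Progress)     # ~ manual_loop.optim_step_progress.total

def seed_global_step(batch_loop: _BatchLoopStub, ckpt: dict, automatic_optimization: bool) -> None:
    # mirrors the added if/else: which counter backs `global_step` depends on the optimization mode
    target = batch_loop.automatic_step if automatic_optimization else batch_loop.manual_step
    target.completed = ckpt["global_step"]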
@@ -28,7 +28,7 @@ import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
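The new import lets the pre-1.6.5 test below run against both optimization modes. Roughly, a manual-optimization counterpart of `BoringModel` has this shape (a sketch, not the `ManualOptimBoringModel` the library actually ships):

# sketch of a manual-optimization BoringModel variant; everything beyond the BoringModel/LightningModule API is illustrative
import torch
from pytorch_lightning.demos.boring_classes import BoringModel

class ManualOptimBoringModelSketch(BoringModel):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # take over the optimization step

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        output = self(batch)
        loss = torch.nn.functional.mse_loss(output, torch.zeros_like(output))
        self.manual_backward(loss)  # replaces the automatic loss.backward()
        opt.step()                  # each call advances global_step by one
        return loss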
@@ -255,6 +255,7 @@ def test_correct_step_and_epoch(tmpdir):
         def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             assert self.trainer.global_step == first_max_epochs * train_batches
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches
 
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
@@ -262,6 +263,29 @@
     assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches
 
 
+@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
+def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
+    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
+    model = model_class()
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    ckpt = torch.load(ckpt_path)
+    # the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
+    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
+    torch.save(ckpt, ckpt_path)
+
+    class TestModel(model_class):
+        def on_train_start(self) -> None:
+            assert self.trainer.global_step == 1
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
+
+    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
+    model = TestModel()
+    trainer.fit(model, ckpt_path=ckpt_path)
+    new_loop = trainer.fit_loop.epoch_loop
+    assert new_loop.global_step == new_loop._batches_that_stepped == 2
+
+
 def test_fit_twice(tmpdir):
     epochs = []
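The test fakes an old checkpoint by deleting `_batches_that_stepped` from a freshly saved one; the same keys can be inspected by hand on a real pre-1.6.5 checkpoint to see what the resume will fall back to (the path below is a placeholder):

# quick manual inspection of a checkpoint; the file path is a placeholder
import torch

ckpt = torch.load("old_run/epoch=9-step=1200.ckpt", map_location="cpu")
loop_state = ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]
print("_batches_that_stepped" in loop_state)  # False for checkpoints saved before v1.6.5
print(ckpt["global_step"])                    # the value the logging step now falls back to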
@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
     ckpt_path = str(tmpdir / "last.ckpt")
 
+    trainer = Trainer(**trainer_args)
+    trainer.strategy.connect(model)
 
     for fn in TrainerFn:
         if fn != TrainerFn.TUNING:
             trainer_fn = getattr(trainer, f"{fn}_loop")