diff --git a/CHANGELOG.md b/CHANGELOG.md
index f57f792b97..421f481c50 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -819,6 +819,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where `ModelCheckpoint` could delete last checkpoint from the old directory when `dirpath` has changed during resumed training ([#12225](https://github.com/PyTorchLightning/pytorch-lightning/pull/12225))
+
+
 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
 
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 457b788437..e2a6b59afa 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -147,7 +147,7 @@ class ModelCheckpoint(Callback):
     then you should create multiple ``ModelCheckpoint`` callbacks.
 
     If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
-    only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+    only ``best_model_path`` will be reloaded and a warning will be issued.
 
     Raises:
         MisconfigurationException:
@@ -337,13 +337,14 @@ class ModelCheckpoint(Callback):
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
+            self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
-                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
+                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
             )
 
-        self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
+        self.best_model_path = state_dict["best_model_path"]
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 7ae59bd342..f8ed4eb746 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1192,8 +1192,8 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
         "kth_best_model_path": False,
         "kth_value": False,
         "best_k_models": False,
+        "last_model_path": False,
         "best_model_path": True,
-        "last_model_path": True,
     }
     for key, should_match in expected_keys.items():
         if should_match:
@@ -1245,6 +1245,40 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     make_assertions(cb_restore, written_ckpt)
 
 
+def test_resume_training_preserves_old_ckpt_last(tmpdir):
+    """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from
+    the old checkpoint."""
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 1,
+        "limit_train_batches": 3,
+        "limit_val_batches": 0,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "filename": "{step}",
+        "monitor": "step",
+        "mode": "max",
+        "save_last": True,
+        "save_top_k": 2,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+    # Make sure that the last checkpoint file exists in the dirpath passed (`tmpdir`)
+    assert set(os.listdir(tmpdir / "checkpoints")) == {"last.ckpt", "step=2.ckpt", "step=3.ckpt"}
+
+    # Training it for 2 epochs for extra surety, that nothing gets deleted after multiple epochs
+    trainer_kwargs["max_epochs"] += 1
+    mc_kwargs["dirpath"] = f"{tmpdir}/new"
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=f"{tmpdir}/checkpoints/step=2.ckpt")
+    # Ensure that the file is not deleted from the old folder
+    assert os.path.isfile(f"{tmpdir}/checkpoints/last.ckpt")
+
+
 def test_save_last_saves_correct_last_model_path(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     mc.CHECKPOINT_NAME_LAST = "{foo}-last"
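For reference, a minimal sketch of the state-reloading contract this patch establishes, exercised directly on a `ModelCheckpoint` instance with the patched `load_state_dict` above. The directory names and the hand-built state dict are illustrative, not taken from the patch. When the saved `dirpath` differs from the current one, `load_state_dict` warns and reloads only `best_model_path`, so the old `last.ckpt` is no longer tracked (and therefore never deleted) by the resumed run:

    # Sketch only; assumes pytorch_lightning at the patched revision.
    import warnings

    from pytorch_lightning.callbacks import ModelCheckpoint

    ckpt_cb = ModelCheckpoint(dirpath="new_dir")  # the resumed run writes here
    state = {
        "dirpath": "old_dir",  # illustrative: differs from the current dirpath
        "best_model_path": "old_dir/best.ckpt",
        "last_model_path": "old_dir/last.ckpt",
    }

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ckpt_cb.load_state_dict(state)  # dirpath mismatch -> warning branch

    assert any("dirpath has changed" in str(w.message) for w in caught)
    assert ckpt_cb.best_model_path == "old_dir/best.ckpt"  # still reloaded
    assert ckpt_cb.last_model_path == ""  # not reloaded, so old last.ckpt is left alone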