Prevent last checkpoint being deleted after resumed training with changed `dirpath` (#12225)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Authored by Kushashwa Ravi Shrimali on 2022-03-28 21:33:00 +05:30; committed by GitHub
parent abe795e285
commit 97121a5d10
3 changed files with 43 additions and 5 deletions

CHANGELOG.md

@@ -819,6 +819,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Fixed an issue where `ModelCheckpoint` could delete last checkpoint from the old directory when `dirpath` has changed during resumed training ([#12225](https://github.com/PyTorchLightning/pytorch-lightning/pull/12225))
 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
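For context, a minimal sketch of the user-facing scenario this entry fixes; the `BoringModel` import path, directory names, and hyperparameters are illustrative assumptions, not taken from the PR:

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers.boring_model import BoringModel  # assumed location of the test helper

model = BoringModel()

# First run: checkpoints, including last.ckpt, land in "run_a/".
trainer = Trainer(max_epochs=1, callbacks=[ModelCheckpoint(dirpath="run_a", save_last=True)])
trainer.fit(model)

# Resume with a *different* dirpath. Before this fix, the callback reloaded
# `last_model_path` from the checkpoint and later removed "run_a/last.ckpt"
# when it wrote the new last checkpoint into "run_b/".
trainer = Trainer(max_epochs=2, callbacks=[ModelCheckpoint(dirpath="run_b", save_last=True)])
trainer.fit(model, ckpt_path=os.path.join("run_a", "last.ckpt"))

assert os.path.isfile(os.path.join("run_a", "last.ckpt"))  # survives after the fix
```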

pytorch_lightning/callbacks/model_checkpoint.py

@@ -147,7 +147,7 @@ class ModelCheckpoint(Callback):
     then you should create multiple ``ModelCheckpoint`` callbacks.

     If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
-    only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+    only ``best_model_path`` will be reloaded and a warning will be issued.

     Raises:
         MisconfigurationException:
@@ -337,13 +337,14 @@ class ModelCheckpoint(Callback):
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
+            self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
-                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
+                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
             )

-        self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         self.best_model_path = state_dict["best_model_path"]

     def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
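Why does skipping the `last_model_path` reload prevent the deletion? When the callback saves a new last checkpoint, it cleans up the file at the previously tracked `last_model_path`. A minimal sketch of that bookkeeping follows; the class and method names are hypothetical stand-ins for the mechanism (the real callback delegates saving and removal to the Trainer's strategy):

```python
import os


class _LastCheckpointSketch:
    """Models only the last_model_path bookkeeping that caused the bug."""

    def __init__(self) -> None:
        # After this fix, this stays empty when resuming with a changed dirpath.
        self.last_model_path = ""

    def save_last_checkpoint(self, filepath: str) -> None:
        previous, self.last_model_path = self.last_model_path, filepath
        with open(filepath, "wb") as f:  # stand-in for the actual checkpoint dump
            f.write(b"ckpt")
        # The previous last checkpoint is removed. If last_model_path had been
        # reloaded from a run with a different dirpath, this is where the old
        # run's last.ckpt would have been deleted.
        if previous and previous != filepath and os.path.isfile(previous):
            os.remove(previous)
```

Leaving `last_model_path` unset when `dirpath` changed means `previous` is empty on the first save into the new directory, so the old run's file is never targeted.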

tests/checkpointing/test_model_checkpoint.py

@@ -1192,8 +1192,8 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
         "kth_best_model_path": False,
         "kth_value": False,
         "best_k_models": False,
-        "last_model_path": False,
         "best_model_path": True,
+        "last_model_path": True,
     }
     for key, should_match in expected_keys.items():
         if should_match:
@@ -1245,6 +1245,40 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     make_assertions(cb_restore, written_ckpt)


+def test_resume_training_preserves_old_ckpt_last(tmpdir):
+    """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from
+    the old checkpoint."""
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 1,
+        "limit_train_batches": 3,
+        "limit_val_batches": 0,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "filename": "{step}",
+        "monitor": "step",
+        "mode": "max",
+        "save_last": True,
+        "save_top_k": 2,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+    # Make sure the last checkpoint file exists in the default dirpath (`tmpdir/checkpoints`)
+    assert set(os.listdir(tmpdir / "checkpoints")) == {"last.ckpt", "step=2.ckpt", "step=3.ckpt"}
+
+    # Train for one more epoch to be sure nothing gets deleted across multiple epochs
+    trainer_kwargs["max_epochs"] += 1
+    mc_kwargs["dirpath"] = f"{tmpdir}/new"
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=f"{tmpdir}/checkpoints/step=2.ckpt")
+    # Ensure that the file is not deleted from the old folder
+    assert os.path.isfile(f"{tmpdir}/checkpoints/last.ckpt")
+
+
 def test_save_last_saves_correct_last_model_path(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     mc.CHECKPOINT_NAME_LAST = "{foo}-last"