Prevent last checkpoint being deleted after resumed training with changed `dirpath` (#12225)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
parent abe795e285
commit 97121a5d10
@@ -819,6 +819,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed an issue where `ModelCheckpoint` could delete last checkpoint from the old directory when `dirpath` has changed during resumed training ([#12225](https://github.com/PyTorchLightning/pytorch-lightning/pull/12225))
+
+
 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
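
To make the scenario behind this entry concrete, here is a minimal, hypothetical sketch (it is not part of the commit): `TinyModel` and the `run_a`/`run_b` directories are made up for illustration, and the final assertion expresses the behaviour the fix is meant to guarantee, namely that resuming with a different `dirpath` no longer deletes the previous run's `last.ckpt`. The test added at the bottom of this diff exercises the same scenario with `BoringModel`.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(LightningModule):
    """Throwaway model used only to drive checkpointing; not part of the commit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)


# First run: checkpoints, including last.ckpt, are written under "run_a".
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    logger=False,
    enable_model_summary=False,
    callbacks=ModelCheckpoint(dirpath="run_a", save_last=True),
)
trainer.fit(TinyModel())
assert os.path.isfile("run_a/last.ckpt")

# Second run: resume from the first run, but point ModelCheckpoint at "run_b".
trainer = Trainer(
    max_epochs=2,
    limit_train_batches=2,
    logger=False,
    enable_model_summary=False,
    callbacks=ModelCheckpoint(dirpath="run_b", save_last=True),
)
trainer.fit(TinyModel(), ckpt_path="run_a/last.ckpt")

# With this fix, writing run_b/last.ckpt no longer removes the old run's last checkpoint.
assert os.path.isfile("run_a/last.ckpt")
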
@@ -147,7 +147,7 @@ class ModelCheckpoint(Callback):
         then you should create multiple ``ModelCheckpoint`` callbacks.

         If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
-        only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+        only ``best_model_path`` will be reloaded and a warning will be issued.

     Raises:
         MisconfigurationException:
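
A quick way to observe the documented behaviour without running a `Trainer` is to feed `ModelCheckpoint.load_state_dict` a state dict whose stored `dirpath` differs from the callback's own, as the hunk below does during restoration. The following sketch is illustrative only: the paths and the contents of `saved_state` are made up, and the keys mirror the ones handled in that hunk.

import warnings

from pytorch_lightning.callbacks import ModelCheckpoint

# Hypothetical state captured by a run that used "old_dir/checkpoints" as dirpath.
saved_state = {
    "dirpath": "old_dir/checkpoints",
    "best_model_path": "old_dir/checkpoints/step=2.ckpt",
    "last_model_path": "old_dir/checkpoints/last.ckpt",
    "best_model_score": None,
}

# Resuming with a different dirpath: only best_model_path is restored.
mc = ModelCheckpoint(dirpath="new_dir/checkpoints", save_last=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mc.load_state_dict(saved_state)

assert mc.best_model_path == "old_dir/checkpoints/step=2.ckpt"
# last_model_path is intentionally not reloaded, so the stale path into old_dir
# is never tracked (and therefore never deleted by a later save of last.ckpt).
assert mc.last_model_path != saved_state["last_model_path"]
assert any("dirpath has changed" in str(w.message) for w in caught)

This also shows why the fix works: once `last_model_path` no longer points into the old directory, the callback has no "previous" last checkpoint there to clean up when it later writes a new `last.ckpt`.
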
@@ -337,13 +337,14 @@ class ModelCheckpoint(Callback):
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
+            self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
-                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
+                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
             )
-        self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
+
         self.best_model_path = state_dict["best_model_path"]

     def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
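
For reference, the keys that round-trip through the callback state can be inspected directly. A small sketch, assuming `ModelCheckpoint.state_dict()` exposes the same keys that `load_state_dict` consumes above (the `some_dir` path is arbitrary):

from pytorch_lightning.callbacks import ModelCheckpoint

mc = ModelCheckpoint(dirpath="some_dir", save_last=True)
state = mc.state_dict()

# Keys handled by load_state_dict above; exact contents can vary between versions.
for key in (
    "dirpath",
    "best_model_path",
    "last_model_path",
    "best_model_score",
    "kth_best_model_path",
    "kth_value",
    "best_k_models",
):
    print(key, "->", state.get(key))

Note that `dirpath` itself is part of the saved state, which is what lets `load_state_dict` detect that the directory has changed.
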
@@ -1192,8 +1192,8 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
         "kth_best_model_path": False,
         "kth_value": False,
         "best_k_models": False,
+        "last_model_path": False,
         "best_model_path": True,
-        "last_model_path": True,
     }
     for key, should_match in expected_keys.items():
         if should_match:
@@ -1245,6 +1245,40 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     make_assertions(cb_restore, written_ckpt)
+
+
+def test_resume_training_preserves_old_ckpt_last(tmpdir):
+    """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from
+    the old checkpoint."""
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 1,
+        "limit_train_batches": 3,
+        "limit_val_batches": 0,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "filename": "{step}",
+        "monitor": "step",
+        "mode": "max",
+        "save_last": True,
+        "save_top_k": 2,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+    # Make sure that the last checkpoint file exists in the dirpath passed (`tmpdir`)
+    assert set(os.listdir(tmpdir / "checkpoints")) == {"last.ckpt", "step=2.ckpt", "step=3.ckpt"}
+
+    # Training it for 2 epochs for extra surety, that nothing gets deleted after multiple epochs
+    trainer_kwargs["max_epochs"] += 1
+    mc_kwargs["dirpath"] = f"{tmpdir}/new"
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=f"{tmpdir}/checkpoints/step=2.ckpt")
+    # Ensure that the file is not deleted from the old folder
+    assert os.path.isfile(f"{tmpdir}/checkpoints/last.ckpt")


 def test_save_last_saves_correct_last_model_path(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     mc.CHECKPOINT_NAME_LAST = "{foo}-last"