Stop loading a few properties if checkpoint's `dirpath` has changed (#12045)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

commit 02ccd874b9 (parent a52a6ea030)
@@ -628,6 +628,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
+
 - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup ([#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752))
 
 - Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
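For context on the entry above: the problem appears when training is resumed while the `ModelCheckpoint` callback is configured with a different `dirpath` than the one stored in the checkpoint. Below is a minimal sketch of that scenario, assuming a Lightning 1.5-era API; `TinyModel`, the random data and the `ckpts_v1`/`ckpts_v2` paths are illustrative placeholders rather than anything from this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    """Throw-away LightningModule used only to exercise checkpoint resuming."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def train_dataloader():
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)


# First run: checkpoints and the callback's bookkeeping live under "ckpts_v1".
first_cb = ModelCheckpoint(dirpath="ckpts_v1", save_last=True)
pl.Trainer(max_epochs=2, callbacks=[first_cb], logger=False).fit(TinyModel(), train_dataloader())

# Resumed run that points the callback at a *different* dirpath. With this fix,
# the score/top-k bookkeeping is not reloaded from the checkpoint, so files under
# the old directory are left alone, and a UserWarning about the changed dirpath
# is emitted instead.
second_cb = ModelCheckpoint(dirpath="ckpts_v2", save_last=True)
pl.Trainer(max_epochs=4, callbacks=[second_cb], logger=False).fit(
    TinyModel(), train_dataloader(), ckpt_path=first_cb.last_model_path
)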
@@ -22,6 +22,7 @@ import logging
 import os
 import re
 import time
+import warnings
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any, Dict, Optional
@@ -144,6 +145,9 @@ class ModelCheckpoint(Callback):
         If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
         then you should create multiple ``ModelCheckpoint`` callbacks.
 
+        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+        only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+
     Raises:
         MisconfigurationException:
             If ``save_top_k`` is smaller than ``-1``,
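To see the documented behaviour in isolation, the hook can be driven directly with a hand-built callback state whose `dirpath` differs from the callback's own. A small sketch, where the state values and paths are made up and the empty-string `trainer`/`pl_module` arguments mirror the test further down in this diff:

import warnings

from pytorch_lightning.callbacks import ModelCheckpoint

# Callback state as it might have been saved under another dirpath
# (illustrative values; only the keys read on this code path are provided).
saved_state = {
    "dirpath": "old_ckpts",
    "best_model_path": "old_ckpts/epoch=3-step=120.ckpt",
    "last_model_path": "old_ckpts/last.ckpt",
}

cb = ModelCheckpoint(dirpath="new_ckpts")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cb.on_load_checkpoint("", "", saved_state)  # trainer/pl_module are unused here

print(cb.best_model_path)      # reloaded: "old_ckpts/epoch=3-step=120.ckpt"
print(cb.last_model_path)      # reloaded: "old_ckpts/last.ckpt"
print(cb.best_model_score)     # not reloaded: left at its initial value
print(str(caught[0].message))  # "The dirpath has changed from 'old_ckpts' to ..."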
@@ -352,12 +356,21 @@ class ModelCheckpoint(Callback):
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
-        self.best_model_score = callback_state["best_model_score"]
-        self.best_model_path = callback_state["best_model_path"]
-        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
-        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
-        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        dirpath_from_ckpt = callback_state.get("dirpath", self.dirpath)
+
+        if self.dirpath == dirpath_from_ckpt:
+            self.best_model_score = callback_state["best_model_score"]
+            self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+            self.kth_value = callback_state.get("kth_value", self.kth_value)
+            self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        else:
+            warnings.warn(
+                f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
+                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+            )
         self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
+        self.best_model_path = callback_state["best_model_path"]
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.
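The guard relies on `on_save_checkpoint` persisting `dirpath` next to the other attributes. A quick sketch for inspecting that state dict, using the same empty-string placeholder arguments as the test below; the listed keys come from the test's expectations, not from a documented guarantee:

from pytorch_lightning.callbacks import ModelCheckpoint

cb = ModelCheckpoint(dirpath="ckpts", save_last=True)
# trainer, pl_module and checkpoint are not needed to assemble the state,
# so empty strings stand in for them (as in the test below).
state = cb.on_save_checkpoint("", "", "")
print(sorted(state))
# Expected to contain at least: "best_k_models", "best_model_path",
# "best_model_score", "dirpath", "kth_best_model_path", "kth_value",
# "last_model_path".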
@@ -1210,12 +1210,30 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
 
 
 def test_model_checkpoint_saveload_ckpt(tmpdir):
+    def make_assertions(cb_restore, written_ckpt):
+        expected_keys = {
+            "dirpath": False,
+            "best_model_score": False,
+            "kth_best_model_path": False,
+            "kth_value": False,
+            "best_k_models": False,
+            "best_model_path": True,
+            "last_model_path": True,
+        }
+        for key, should_match in expected_keys.items():
+            if should_match:
+                assert getattr(cb_restore, key) == written_ckpt[key]
+            else:
+                assert getattr(cb_restore, key) != written_ckpt[key]
+
+    class CustomModelCheckpoint(ModelCheckpoint):
+        def on_load_checkpoint(self, *args, **kwargs):
+            assert self.dirpath is not None
+            return super().on_load_checkpoint(*args, **kwargs)
+
     ckpt = {
         "monitor": "random_value",
         "best_model_path": "epoch=10-step=1436.ckpt",
         "best_model_score": torch.tensor(2.246),
         "current_score": torch.tensor(1.5),
         "dirpath": tmpdir,
         "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
         "kth_best_model_path": "epoch=10-step=1436.ckpt",
         "kth_value": torch.tensor(2.246),
@@ -1223,24 +1241,33 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     }
 
     # test on_save_checkpoint
-    cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
+    cb_write = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, save_last=True)
     for key, val in ckpt.items():
         setattr(cb_write, key, val)
     written_ckpt = cb_write.on_save_checkpoint("", "", "")
     for state in ckpt:
         assert ckpt[state] == written_ckpt[state]
 
+    # Case - 1
     # test on_load_checkpoint
-    # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
-    # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
-    # "current_score" is left as initialized, i.e. None, and can therefore also be asserted
-    cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
-    cb_restore.on_load_checkpoint("", "", written_ckpt)
-    for key, val in written_ckpt.items():
-        if key not in ("current_score", "dirpath", "monitor"):
-            assert getattr(cb_restore, key) == val
-        else:
-            assert getattr(cb_restore, key) != val
+    # Notes:
+    # 1. "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
+    #    We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
+    # 2. "current_score" is left as initialized, i.e. None, and can therefore also be asserted
+    # 3. When a different `dirpath` is passed to `ModelCheckpoint` to resume training, only
+    #    `best_model_path` and `last_model_path` are reloaded (reloading for others is stopped).
+    cb_restore = ModelCheckpoint(dirpath=tmpdir + "/restore", monitor=None, save_top_k=-1, save_last=True)
+    with pytest.warns(UserWarning, match="The dirpath has changed from*"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)
+
+    # Case - 2
+    # Make sure that everything runs when dirpath is not initialized explicitly
+    cb_restore = CustomModelCheckpoint()
+    cb_restore.setup(Trainer(), BoringModel())
+    with pytest.warns(UserWarning, match="The dirpath has changed from*"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)
 
 
 def test_save_last_saves_correct_last_model_path(tmpdir):
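One note on the `pytest.warns` calls in this test: the `match` argument is a regular expression applied with `re.search`, so the trailing `*` in `"The dirpath has changed from*"` is a regex quantifier (zero or more `m` characters) rather than a shell-style wildcard; the assertion still passes because the literal prefix is found in the message. A self-contained illustration of that matching, independent of Lightning:

import re

# What pytest.warns(..., match=pattern) effectively checks against the warning text:
pattern = "The dirpath has changed from*"
message = "The dirpath has changed from 'old_ckpts' to 'new_ckpts', therefore ..."
assert re.search(pattern, message)  # the "m*" quantifier matches the trailing "m" of "from"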