Add required states for resumed ModelCheckpoint GC (#10995)
* Add required states for resumed ModelCheckpoint GC
* Add backwards compatibility with legacy ckpts
* Add test to check if attrs are written to ckpt. Note that we do not yet check for proper loading/re-instantiation of ModelCheckpoint based on the ckpt written to disk.
* Test that attributes are restored properly from ckpt
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix broken `test_callbacks_state_fit_ckpt_path`: `ModelCheckpoint` is configured to save after every epoch, but `trainer.fit` is called with `max_steps=1`. Note there may be a better way of doing this, where `ModelCheckpoint` is called after `training_step`.
* Update test_restore.py
* Check that all attributes are restored properly
* Revert changes, use fix on master
* Convert to proper unit test
* Refactor `test_model_checkpoint_saveload_ckpt`: first save, then load ckpt; instantiate ModelCheckpoint twice.

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent b1baf460d9
commit 86a3c5e2a3
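Taken together, the two hunks below mean that `on_save_checkpoint` now persists, and `on_load_checkpoint` now restores, all of the bookkeeping that top-k checkpoint garbage collection needs when a fit is resumed. Roughly, the persisted callback state has the following shape after this change; this is a sketch, and all values are made-up placeholders rather than output of a real run:

import torch

# Illustrative shape of the dict returned by ModelCheckpoint.on_save_checkpoint
# after this change. Keys marked "new" are the ones this commit adds; every
# value below is a placeholder.
callback_state = {
    "monitor": "val_loss",
    "best_model_score": torch.tensor(0.25),
    "best_model_path": "epoch=3-step=400.ckpt",
    "current_score": torch.tensor(0.30),
    "dirpath": "/runs/checkpoints",
    "best_k_models": {"epoch=3-step=400.ckpt": torch.tensor(0.25)},  # new
    "kth_best_model_path": "epoch=3-step=400.ckpt",                  # new
    "kth_value": torch.tensor(0.25),                                 # new
    "last_model_path": "last.ckpt",                                  # new
}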
@@ -344,6 +344,10 @@ class ModelCheckpoint(Callback):
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
+            "best_k_models": self.best_k_models,
+            "kth_best_model_path": self.kth_best_model_path,
+            "kth_value": self.kth_value,
+            "last_model_path": self.last_model_path,
         }
 
     def on_load_checkpoint(
@@ -351,6 +355,10 @@ class ModelCheckpoint(Callback):
     ) -> None:
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
+        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.
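The `callback_state.get(...)` fallbacks above are what provide the backwards compatibility with legacy ckpts mentioned in the commit message: a checkpoint written before this change simply lacks the new keys, so the callback keeps its freshly initialized defaults. A minimal sketch of that behavior follows; the legacy state dict is invented, and the call signature mirrors the unit test added later in this diff:

import torch
from pytorch_lightning.callbacks import ModelCheckpoint

# A legacy callback state as written before this commit: only the original
# five keys, none of the new top-k bookkeeping (values are placeholders).
legacy_state = {
    "monitor": "val_loss",
    "best_model_score": torch.tensor(0.25),
    "best_model_path": "epoch=3-step=400.ckpt",
    "current_score": torch.tensor(0.30),
    "dirpath": "/old/run/checkpoints",
}

cb = ModelCheckpoint(dirpath="/new/run/checkpoints", monitor="val_loss", save_top_k=3)
cb.on_load_checkpoint("", "", legacy_state)  # trainer/pl_module are unused here

assert cb.best_model_path == "epoch=3-step=400.ckpt"  # restored from the legacy state
assert cb.best_k_models == {}  # key missing -> .get() falls back to the default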
@@ -1202,3 +1202,37 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
     )
     trainer.fit(model)
     assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}
+
+
+def test_model_checkpoint_saveload_ckpt(tmpdir):
+    ckpt = {
+        "monitor": "random_value",
+        "best_model_path": "epoch=10-step=1436.ckpt",
+        "best_model_score": torch.tensor(2.246),
+        "current_score": torch.tensor(1.5),
+        "dirpath": tmpdir,
+        "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
+        "kth_best_model_path": "epoch=10-step=1436.ckpt",
+        "kth_value": torch.tensor(2.246),
+        "last_model_path": "last2245.ckpt",
+    }
+
+    # test on_save_checkpoint
+    cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
+    for key, val in ckpt.items():
+        setattr(cb_write, key, val)
+    written_ckpt = cb_write.on_save_checkpoint("", "", "")
+    for state in ckpt:
+        assert ckpt[state] == written_ckpt[state]
+
+    # test on_load_checkpoint
+    # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
+    # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
+    # "current_score" is left as initialized, i.e. None, and can therefore also be asserted.
+    cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
+    cb_restore.on_load_checkpoint("", "", written_ckpt)
+    for key, val in written_ckpt.items():
+        if key not in ("current_score", "dirpath", "monitor"):
+            assert getattr(cb_restore, key) == val
+        else:
+            assert getattr(cb_restore, key) != val
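The commit message notes that loading/re-instantiation of `ModelCheckpoint` from a ckpt actually written to disk is not yet checked; the unit test above only passes the in-memory dict between the two hooks. A hypothetical follow-up check (not part of this commit; the helper name and file name are invented) could round-trip the state through a real file:

import os
import torch

# Hypothetical extension of the test above: persist the callback state to disk
# and restore it into a second ModelCheckpoint instance, rather than handing
# the in-memory dict straight from on_save_checkpoint to on_load_checkpoint.
def roundtrip_ckpt_state(cb_write, cb_restore, tmpdir):
    written_ckpt = cb_write.on_save_checkpoint("", "", "")
    path = os.path.join(str(tmpdir), "callback_state.ckpt")
    torch.save(written_ckpt, path)  # write the state dict to an actual file
    cb_restore.on_load_checkpoint("", "", torch.load(path))
    return written_ckpt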
@@ -269,8 +269,15 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
 
     for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
         if isinstance(before, ModelCheckpoint):
-            assert before.best_model_path == after.best_model_path
-            assert before.best_model_score == after.best_model_score
+            for attribute in (
+                "best_model_path",
+                "best_model_score",
+                "best_k_models",
+                "kth_best_model_path",
+                "kth_value",
+                "last_model_path",
+            ):
+                assert getattr(before, attribute) == getattr(after, attribute)
 
 
 def test_callbacks_references_fit_ckpt_path(tmpdir):