diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 7707f8a377..d12752eabf 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -344,6 +344,10 @@ class ModelCheckpoint(Callback):
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
+            "best_k_models": self.best_k_models,
+            "kth_best_model_path": self.kth_best_model_path,
+            "kth_value": self.kth_value,
+            "last_model_path": self.last_model_path,
         }

     def on_load_checkpoint(
@@ -351,6 +355,10 @@ class ModelCheckpoint(Callback):
     ) -> None:
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
+        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        self.last_model_path = callback_state.get("last_model_path", self.last_model_path)

     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index fa08057733..a332490585 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1202,3 +1202,37 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
     )
     trainer.fit(model)
     assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}
+
+
+def test_model_checkpoint_saveload_ckpt(tmpdir):
+    ckpt = {
+        "monitor": "random_value",
+        "best_model_path": "epoch=10-step=1436.ckpt",
+        "best_model_score": torch.tensor(2.246),
+        "current_score": torch.tensor(1.5),
+        "dirpath": tmpdir,
+        "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
+        "kth_best_model_path": "epoch=10-step=1436.ckpt",
+        "kth_value": torch.tensor(2.246),
+        "last_model_path": "last2245.ckpt",
+    }
+
+    # test on_save_checkpoint
+    cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
+    for key, val in ckpt.items():
+        setattr(cb_write, key, val)
+    written_ckpt = cb_write.on_save_checkpoint("", "", "")
+    for state in ckpt:
+        assert ckpt[state] == written_ckpt[state]
+
+    # test on_load_checkpoint
+    # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
+    # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
+    # "current_score" is left as initialized, i.e. None, and can therefore also be asserted
+    cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
+    cb_restore.on_load_checkpoint("", "", written_ckpt)
+    for key, val in written_ckpt.items():
+        if key not in ("current_score", "dirpath", "monitor"):
+            assert getattr(cb_restore, key) == val
+        else:
+            assert getattr(cb_restore, key) != val
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index b801041410..9d6ecb08cd 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -269,8 +269,15 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):

     for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
         if isinstance(before, ModelCheckpoint):
-            assert before.best_model_path == after.best_model_path
-            assert before.best_model_score == after.best_model_score
+            for attribute in (
+                "best_model_path",
+                "best_model_score",
+                "best_k_models",
+                "kth_best_model_path",
+                "kth_value",
+                "last_model_path",
+            ):
+                assert getattr(before, attribute) == getattr(after, attribute)


 def test_callbacks_references_fit_ckpt_path(tmpdir):