From 02ccd874b9eda885df7752710e605e259d579d9c Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 28 Feb 2022 22:12:09 +0530
Subject: [PATCH] Stop loading a few properties if checkpoint's `dirpath` has changed (#12045)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rohit Gupta
Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                 |  3 +
 .../callbacks/model_checkpoint.py            | 23 ++++++++--
 tests/checkpointing/test_model_checkpoint.py | 55 ++++++++++++++-----
 3 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56ff447628..9606b7d653 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -628,6 +628,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
+
+
 - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)
 
 - Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 1e231abab2..6ae01add66 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -22,6 +22,7 @@ import logging
 import os
 import re
 import time
+import warnings
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any, Dict, Optional
@@ -144,6 +145,9 @@ class ModelCheckpoint(Callback):
         If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should
         create multiple ``ModelCheckpoint`` callbacks.
 
+        If the checkpoint's ``dirpath`` has changed from the one it was saved with, then when training resumes,
+        only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+
     Raises:
         MisconfigurationException:
             If ``save_top_k`` is smaller than ``-1``,
@@ -352,12 +356,21 @@ class ModelCheckpoint(Callback):
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
-        self.best_model_score = callback_state["best_model_score"]
-        self.best_model_path = callback_state["best_model_path"]
-        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
-        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
-        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        dirpath_from_ckpt = callback_state.get("dirpath", self.dirpath)
+
+        if self.dirpath == dirpath_from_ckpt:
+            self.best_model_score = callback_state["best_model_score"]
+            self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+            self.kth_value = callback_state.get("kth_value", self.kth_value)
+            self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        else:
+            warnings.warn(
+                f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
+                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+            )
         self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
+        self.best_model_path = callback_state["best_model_path"]
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 2c65426534..b9e63d28f4 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1210,12 +1210,30 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
 
 
 def test_model_checkpoint_saveload_ckpt(tmpdir):
+    def make_assertions(cb_restore, written_ckpt):
+        expected_keys = {
+            "dirpath": False,
+            "best_model_score": False,
+            "kth_best_model_path": False,
+            "kth_value": False,
+            "best_k_models": False,
+            "best_model_path": True,
+            "last_model_path": True,
+        }
+        for key, should_match in expected_keys.items():
+            if should_match:
+                assert getattr(cb_restore, key) == written_ckpt[key]
+            else:
+                assert getattr(cb_restore, key) != written_ckpt[key]
+
+    class CustomModelCheckpoint(ModelCheckpoint):
+        def on_load_checkpoint(self, *args, **kwargs):
+            assert self.dirpath is not None
+            return super().on_load_checkpoint(*args, **kwargs)
+
     ckpt = {
-        "monitor": "random_value",
         "best_model_path": "epoch=10-step=1436.ckpt",
         "best_model_score": torch.tensor(2.246),
-        "current_score": torch.tensor(1.5),
-        "dirpath": tmpdir,
         "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
         "kth_best_model_path": "epoch=10-step=1436.ckpt",
         "kth_value": torch.tensor(2.246),
@@ -1223,24 +1241,33 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     }
 
     # test on_save_checkpoint
-    cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
+    cb_write = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, save_last=True)
     for key, val in ckpt.items():
         setattr(cb_write, key, val)
     written_ckpt = cb_write.on_save_checkpoint("", "", "")
     for state in ckpt:
         assert ckpt[state] == written_ckpt[state]
 
+    # Case 1
     # test on_load_checkpoint
-    # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
-    # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
-    # "current_score" is left as initialized, i.e. None, and can therefore also be asserted
-    cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
-    cb_restore.on_load_checkpoint("", "", written_ckpt)
-    for key, val in written_ckpt.items():
-        if key not in ("current_score", "dirpath", "monitor"):
-            assert getattr(cb_restore, key) == val
-        else:
-            assert getattr(cb_restore, key) != val
+    # Notes:
+    # 1. "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
+    #    We therefore set "dirpath" and "monitor" to something different from the ones in ckpt/cb_write so we can assert them.
+    # 2. "current_score" is left as initialized, i.e. None, and can therefore also be asserted.
+    # 3. When a different `dirpath` is passed to `ModelCheckpoint` to resume training, only
+    #    `best_model_path` and `last_model_path` are reloaded; the rest are left as initialized.
+    cb_restore = ModelCheckpoint(dirpath=tmpdir + "/restore", monitor=None, save_top_k=-1, save_last=True)
+    with pytest.warns(UserWarning, match="The dirpath has changed from"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)
+
+    # Case 2
+    # Make sure that everything runs when dirpath is not initialized explicitly
+    cb_restore = CustomModelCheckpoint()
+    cb_restore.setup(Trainer(), BoringModel())
+    with pytest.warns(UserWarning, match="The dirpath has changed from"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)
 
 
 def test_save_last_saves_correct_last_model_path(tmpdir):
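
Reviewer note (not part of the patch): the sketch below walks through the user-facing
scenario this change addresses: a run that writes checkpoints under one `dirpath` and is
then resumed with a different one. It is a minimal sketch only. `TinyModel`,
`RandomDataset` and the "ckpts_a"/"ckpts_b" paths are invented for illustration, and it
assumes the `Trainer.fit(ckpt_path=...)` resume API available around this release.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    """Random fixed-size vectors, just enough to drive the training loop."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Log an epoch-level metric for ModelCheckpoint to monitor.
        self.log("loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)


# First run: top-k checkpoints and "last.ckpt" are written under "ckpts_a".
model = TinyModel()
ckpt_cb = ModelCheckpoint(dirpath="ckpts_a", monitor="loss", save_top_k=2, save_last=True)
pl.Trainer(max_epochs=2, callbacks=[ckpt_cb]).fit(model)

# Resumed run with a *different* dirpath. Before this patch, `best_k_models` and
# friends restored from the checkpoint's callback state pointed at the old files,
# so ModelCheckpoint's top-k bookkeeping could delete them. With the patch, only
# `best_model_path` and `last_model_path` are reloaded, and the
# "The dirpath has changed from ..." UserWarning is emitted instead.
ckpt_cb = ModelCheckpoint(dirpath="ckpts_b", monitor="loss", save_top_k=2, save_last=True)
pl.Trainer(max_epochs=4, callbacks=[ckpt_cb]).fit(model, ckpt_path="ckpts_a/last.ckpt")

Still reloading `best_model_path` and `last_model_path` keeps references to the previous
run's files available (for example, `trainer.test(ckpt_path="best")` resolves through
`best_model_path`), while the stale top-k bookkeeping is discarded.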