Stop loading a few properties if checkpoint's `dirpath` has changed (#12045)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Kushashwa Ravi Shrimali 2022-02-28 22:12:09 +05:30 committed by GitHub
parent a52a6ea030
commit 02ccd874b9
3 changed files with 62 additions and 19 deletions

CHANGELOG.md

@@ -628,6 +628,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
+
 - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup ([#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752))
 - Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
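
The `ModelCheckpoint` entry above refers to resumed training where the callback points at a new `dirpath`. As a rough, self-contained sketch of that scenario (the `TinyModule`, directory names, and hyperparameters below are illustrative placeholders, not taken from the PR):

```python
# Hedged sketch: resume training with a ModelCheckpoint whose `dirpath` differs
# from the one stored in the checkpoint. With this change, the top-k bookkeeping
# (`best_k_models`, `kth_value`, `kth_best_model_path`, `best_model_score`) is not
# restored, so checkpoints from the first run are no longer trimmed away; only
# `best_model_path` and `last_model_path` are reloaded, and a UserWarning is issued.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModule(pl.LightningModule):
    """Placeholder module used only to drive the example."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)

# First run: checkpoints are written to "checkpoints/run_1".
cb_1 = ModelCheckpoint(dirpath="checkpoints/run_1", monitor="train_loss", save_top_k=2, save_last=True)
pl.Trainer(max_epochs=2, callbacks=[cb_1]).fit(TinyModule(), train_loader)

# Resumed run with a *different* dirpath: the callback state from the checkpoint is
# only partially restored (best/last paths), and a warning about the changed dirpath is emitted.
cb_2 = ModelCheckpoint(dirpath="checkpoints/run_2", monitor="train_loss", save_top_k=2, save_last=True)
pl.Trainer(max_epochs=4, callbacks=[cb_2]).fit(TinyModule(), train_loader, ckpt_path=cb_1.last_model_path)
```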

pytorch_lightning/callbacks/model_checkpoint.py

@@ -22,6 +22,7 @@ import logging
 import os
 import re
 import time
+import warnings
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any, Dict, Optional
@@ -144,6 +145,9 @@ class ModelCheckpoint(Callback):
         If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
         then you should create multiple ``ModelCheckpoint`` callbacks.

+        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
+        only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+
     Raises:
         MisconfigurationException:
             If ``save_top_k`` is smaller than ``-1``,
@@ -352,12 +356,21 @@
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
-        self.best_model_score = callback_state["best_model_score"]
-        self.best_model_path = callback_state["best_model_path"]
-        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
-        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
-        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        dirpath_from_ckpt = callback_state.get("dirpath", self.dirpath)
+
+        if self.dirpath == dirpath_from_ckpt:
+            self.best_model_score = callback_state["best_model_score"]
+            self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+            self.kth_value = callback_state.get("kth_value", self.kth_value)
+            self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        else:
+            warnings.warn(
+                f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
+                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+            )
         self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
+        self.best_model_path = callback_state["best_model_path"]

     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.

tests/checkpointing/test_model_checkpoint.py

@@ -1210,12 +1210,30 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
 def test_model_checkpoint_saveload_ckpt(tmpdir):
+    def make_assertions(cb_restore, written_ckpt):
+        expected_keys = {
+            "dirpath": False,
+            "best_model_score": False,
+            "kth_best_model_path": False,
+            "kth_value": False,
+            "best_k_models": False,
+            "best_model_path": True,
+            "last_model_path": True,
+        }
+        for key, should_match in expected_keys.items():
+            if should_match:
+                assert getattr(cb_restore, key) == written_ckpt[key]
+            else:
+                assert getattr(cb_restore, key) != written_ckpt[key]
+
+    class CustomModelCheckpoint(ModelCheckpoint):
+        def on_load_checkpoint(self, *args, **kwargs):
+            assert self.dirpath is not None
+            return super().on_load_checkpoint(*args, **kwargs)
+
     ckpt = {
         "monitor": "random_value",
         "best_model_path": "epoch=10-step=1436.ckpt",
         "best_model_score": torch.tensor(2.246),
         "current_score": torch.tensor(1.5),
+        "dirpath": tmpdir,
         "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
         "kth_best_model_path": "epoch=10-step=1436.ckpt",
         "kth_value": torch.tensor(2.246),
@@ -1223,24 +1241,33 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     }
     # test on_save_checkpoint
-    cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
+    cb_write = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, save_last=True)
     for key, val in ckpt.items():
         setattr(cb_write, key, val)
     written_ckpt = cb_write.on_save_checkpoint("", "", "")
     for state in ckpt:
         assert ckpt[state] == written_ckpt[state]
+
+    # Case - 1
     # test on_load_checkpoint
-    # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
-    # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
-    # "current_score" is left as initialized, i.e. None, and can therefore also be asserted
-    cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
-    cb_restore.on_load_checkpoint("", "", written_ckpt)
-    for key, val in written_ckpt.items():
-        if key not in ("current_score", "dirpath", "monitor"):
-            assert getattr(cb_restore, key) == val
-        else:
-            assert getattr(cb_restore, key) != val
+    # Notes:
+    # 1. "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
+    #    We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
+    # 2. "current_score" is left as initialized, i.e. None, and can therefore also be asserted
+    # 3. When a different `dirpath` is passed to `ModelCheckpoint` to resume training, only
+    #    `best_model_path` and `last_model_path` are reloaded (reloading for others is stopped).
+    cb_restore = ModelCheckpoint(dirpath=tmpdir + "/restore", monitor=None, save_top_k=-1, save_last=True)
+    with pytest.warns(UserWarning, match="The dirpath has changed from*"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)
+
+    # Case - 2
+    # Make sure that everything runs when dirpath is not initialized explicitly
+    cb_restore = CustomModelCheckpoint()
+    cb_restore.setup(Trainer(), BoringModel())
+    with pytest.warns(UserWarning, match="The dirpath has changed from*"):
+        cb_restore.on_load_checkpoint("", "", written_ckpt)
+    make_assertions(cb_restore, written_ckpt)

 def test_save_last_saves_correct_last_model_path(tmpdir):
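
As a closing illustration of the restored-state semantics exercised by `test_model_checkpoint_saveload_ckpt` above, the sketch below feeds a hand-built callback state whose `dirpath` differs from the callback's own directly into `on_load_checkpoint`. The paths and scores are made-up values, and the sketch assumes a Lightning build that includes this change.

```python
# Hedged sketch, not part of the PR: exercise the dirpath check directly.
import warnings

import torch

from pytorch_lightning.callbacks import ModelCheckpoint

# Keys mirror what on_save_checkpoint writes; values are illustrative.
callback_state = {
    "dirpath": "/old/dir",
    "best_model_path": "/old/dir/epoch=10-step=1436.ckpt",
    "best_model_score": torch.tensor(2.246),
    "best_k_models": {"/old/dir/epoch=10-step=1436.ckpt": torch.tensor(2.246)},
    "kth_best_model_path": "/old/dir/epoch=10-step=1436.ckpt",
    "kth_value": torch.tensor(2.246),
    "last_model_path": "/old/dir/last.ckpt",
}

# The callback is configured with a different dirpath than the one stored in the state.
cb = ModelCheckpoint(dirpath="/new/dir")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cb.on_load_checkpoint(trainer=None, pl_module=None, callback_state=callback_state)

assert cb.best_model_path == "/old/dir/epoch=10-step=1436.ckpt"  # reloaded
assert cb.last_model_path == "/old/dir/last.ckpt"                # reloaded
assert cb.best_model_score is None                               # left untouched
assert cb.best_k_models == {}                                    # left untouched
assert any("dirpath has changed" in str(w.message) for w in caught)
```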