Versioning of last checkpoins (#12902)

* last checkpoint versioning

* changelog

* Simplify test

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
otaj 2022-04-29 07:13:50 +02:00 committed by GitHub
parent 74d46d655d
commit c461854fa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 3 deletions

View File

@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
-
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))
-
@ -53,7 +53,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Make positional arguments required for classes passed into the `add_argparse_args` function. ([#12504](https://github.com/PyTorchLightning/pytorch-lightning/pull/12504))
-
-

View File

@ -640,6 +640,12 @@ class ModelCheckpoint(Callback):
return
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
version_cnt += 1
# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
self._save_checkpoint(trainer, filepath)

View File

@ -1299,6 +1299,23 @@ def test_save_last_saves_correct_last_model_path(tmpdir):
assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path
def test_save_last_versioning(tmpdir):
model = BoringModel()
for _ in range(2):
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True)
trainer = Trainer(
max_epochs=2,
callbacks=mc,
limit_train_batches=1,
limit_val_batches=0,
enable_progress_bar=False,
enable_model_summary=False,
logger=False,
)
trainer.fit(model)
assert {"last.ckpt", "last-v1.ckpt"} == set(os.listdir(tmpdir))
def test_none_monitor_saves_correct_best_model_path(tmpdir):
mc = ModelCheckpoint(dirpath=tmpdir, monitor=None)
trainer = Trainer(callbacks=mc)

View File

@ -76,4 +76,4 @@ def test_checkpoint_plugin_called(tmpdir):
trainer.test(model, ckpt_path=ck.last_model_path)
checkpoint_plugin.load_checkpoint.assert_called_once()
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")