Versioning of last checkpoints (#12902)
* last checkpoint versioning
* changelog
* Simplify test
* Update CHANGELOG.md
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Update CHANGELOG.md
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
74d46d655d
commit
c461854fa7
|
@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
|
||||
|
||||
|
||||
-
|
||||
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))
|
||||
|
||||
|
||||
-
|
||||
|
@ -53,7 +53,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Make positional arguments required for classes passed into the `add_argparse_args` function. ([#12504](https://github.com/PyTorchLightning/pytorch-lightning/pull/12504))
|
||||
|
||||
|
||||
-
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -640,6 +640,12 @@ class ModelCheckpoint(Callback):
|
|||
return
|
||||
|
||||
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
|
||||
|
||||
version_cnt = self.STARTING_VERSION
|
||||
while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
|
||||
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
|
||||
version_cnt += 1
|
||||
|
||||
# set the last model path before saving because it will be part of the state.
|
||||
previous, self.last_model_path = self.last_model_path, filepath
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
|
|
|
@ -1299,6 +1299,23 @@ def test_save_last_saves_correct_last_model_path(tmpdir):
|
|||
assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path
|
||||
|
||||
|
||||
def test_save_last_versioning(tmpdir):
|
||||
model = BoringModel()
|
||||
for _ in range(2):
|
||||
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True)
|
||||
trainer = Trainer(
|
||||
max_epochs=2,
|
||||
callbacks=mc,
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=0,
|
||||
enable_progress_bar=False,
|
||||
enable_model_summary=False,
|
||||
logger=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert {"last.ckpt", "last-v1.ckpt"} == set(os.listdir(tmpdir))
|
||||
|
||||
|
||||
def test_none_monitor_saves_correct_best_model_path(tmpdir):
|
||||
mc = ModelCheckpoint(dirpath=tmpdir, monitor=None)
|
||||
trainer = Trainer(callbacks=mc)
|
||||
|
|
|
@ -76,4 +76,4 @@ def test_checkpoint_plugin_called(tmpdir):
|
|||
|
||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||
checkpoint_plugin.load_checkpoint.assert_called_once()
|
||||
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")
|
||||
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")
|
||||
|
|
Loading…
Reference in New Issue