Set the state before saving "last" or "none" checkpoints (#11481)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Authored by Nithin Rao on 2022-02-02 14:07:05 -08:00; committed by GitHub
parent 3b699030d5
commit b8d2c65a37
4 changed files with 42 additions and 10 deletions

CHANGELOG.md

@@ -451,6 +451,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
+- Fixed bug where the path for "last" checkpoints was not getting saved correctly which caused newer runs to not remove the previous "last" checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
+- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
 - Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
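
Note on the root cause: `trainer.save_checkpoint` serializes every callback's `state_dict` into the checkpoint file itself, so `last_model_path` and `best_model_path` have to be updated before the save, or the checkpoint records the stale previous path and a resumed run can neither clean it up nor find its best model. A minimal sketch of the ordering bug, using stand-in classes rather than Lightning's:

```python
# Minimal sketch of the ordering bug with stand-in classes (not Lightning's).
# `save_fn` plays the role of `trainer.save_checkpoint`, which embeds the
# callback's `state_dict()` into the checkpoint it writes.


class Checkpointer:
    def __init__(self) -> None:
        self.last_model_path = ""

    def state_dict(self) -> dict:
        return {"last_model_path": self.last_model_path}

    def save_buggy(self, save_fn, filepath: str) -> None:
        save_fn(filepath, self.state_dict())  # state still holds the OLD path
        self.last_model_path = filepath       # updated too late to be persisted

    def save_fixed(self, save_fn, filepath: str) -> None:
        # set the path first so it becomes part of the persisted state
        previous, self.last_model_path = self.last_model_path, filepath
        save_fn(filepath, self.state_dict())


saved = {}


def store(path: str, state: dict) -> None:
    saved[path] = state


Checkpointer().save_buggy(store, "epoch=0-last.ckpt")
assert saved["epoch=0-last.ckpt"]["last_model_path"] == ""  # stale!

Checkpointer().save_fixed(store, "epoch=0-last.ckpt")
assert saved["epoch=0-last.ckpt"]["last_model_path"] == "epoch=0-last.ckpt"
```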

pytorch_lightning/callbacks/model_checkpoint.py

@@ -641,12 +641,11 @@ class ModelCheckpoint(Callback):
             return
 
         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
+        # set the last model path before saving because it will be part of the state.
+        previous, self.last_model_path = self.last_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.last_model_path and self.last_model_path != filepath:
-            trainer.strategy.remove_checkpoint(self.last_model_path)
-
-        self.last_model_path = filepath
+        if previous and previous != filepath:
+            trainer.strategy.remove_checkpoint(previous)
 
     def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is None or self.save_top_k == 0:
@@ -666,12 +665,11 @@ class ModelCheckpoint(Callback):
             return
 
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
+        # set the best model path before saving because it will be part of the state.
+        previous, self.best_model_path = self.best_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
-            trainer.strategy.remove_checkpoint(self.best_model_path)
-
-        self.best_model_path = filepath
+        if self.save_top_k == 1 and previous and previous != filepath:
+            trainer.strategy.remove_checkpoint(previous)
 
     def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
         return self.monitor in metrics or len(metrics) == 0
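
Both methods now use the same swap-then-clean-up idiom: stash the previous path, point the attribute at the new file, save (which persists the updated state), and only then delete the superseded checkpoint. The same idiom in isolation, with plain file I/O standing in for `trainer.save_checkpoint` and `trainer.strategy.remove_checkpoint` (`save_and_rotate` is a hypothetical helper, not Lightning code):

```python
# Hedged sketch of the swap-then-clean-up idiom; plain file I/O stands in
# for `trainer.save_checkpoint` and `trainer.strategy.remove_checkpoint`.
import os
import tempfile


def save_and_rotate(state: dict, new_path: str) -> None:
    # 1. record the new path in the state *before* serializing it
    previous, state["last_model_path"] = state.get("last_model_path", ""), new_path
    # 2. persist the state, path included (stand-in for trainer.save_checkpoint)
    with open(new_path, "w") as f:
        f.write(repr(state))
    # 3. only then remove the superseded file (stand-in for remove_checkpoint)
    if previous and previous != new_path and os.path.exists(previous):
        os.remove(previous)


tmpdir = tempfile.mkdtemp()
state: dict = {}
save_and_rotate(state, os.path.join(tmpdir, "epoch=0-last.ckpt"))
save_and_rotate(state, os.path.join(tmpdir, "epoch=1-last.ckpt"))
assert os.listdir(tmpdir) == ["epoch=1-last.ckpt"]  # the stale "last" is gone
```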

tests/checkpointing/test_model_checkpoint.py

@@ -1234,3 +1234,30 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
             assert getattr(cb_restore, key) == val
         else:
             assert getattr(cb_restore, key) != val
+
+
+def test_save_last_saves_correct_last_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    mc.CHECKPOINT_NAME_LAST = "{foo}-last"
+    trainer = Trainer(callbacks=mc)
+    trainer.strategy.connect(BoringModel())
+
+    mc._save_last_checkpoint(trainer, {"foo": 1})
+    expected = "foo=1-last.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path
+
+
+def test_none_monitor_saves_correct_best_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor=None)
+    trainer = Trainer(callbacks=mc)
+    trainer.strategy.connect(BoringModel())
+
+    mc._save_none_monitor_checkpoint(trainer, {})
+    expected = "epoch=0-step=0.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path
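
These assertions work because Lightning writes each callback's `state_dict` into the checkpoint under `checkpoint["callbacks"][callback.state_key]`. A quick way to inspect that mapping outside the test suite (hedged sketch; the checkpoint path is hypothetical):

```python
# Hedged sketch: inspect the callback states stored inside a Lightning
# checkpoint file. The path below is hypothetical.
import torch

path = "checkpoints/epoch=0-step=0.ckpt"
ckpt = torch.load(path, map_location="cpu")
# Lightning stores one entry per callback, keyed by `callback.state_key`.
for state_key, state in ckpt.get("callbacks", {}).items():
    print(state_key, "->", state)
# For ModelCheckpoint the entry should now include an up-to-date
# "last_model_path" / "best_model_path" matching the file it lives in.
```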

tests/deprecated_api/test_remove_2-0.py

@@ -40,6 +40,7 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
     with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
         trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path  # last `fit` replaced the `best_model_path`
     assert callback.state == 111
     assert trainer.checkpoint_connector.resume_checkpoint_path is None
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None