Set the state before saving "last" or "none" checkpoints (#11481)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
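
Why the ordering matters: `trainer.save_checkpoint` snapshots each callback's `state_dict()` into the checkpoint as it is written (the tests below read it back via `ckpt["callbacks"][mc.state_key]`), so any attribute assigned after the call is missing from the file on disk. A minimal sketch of the problem using toy stand-ins, not the real ModelCheckpoint API:

    class ToyCheckpointCallback:
        # Toy stand-in for ModelCheckpoint; names and structure are illustrative only.
        def __init__(self) -> None:
            self.last_model_path = ""

        def state_dict(self) -> dict:
            return {"last_model_path": self.last_model_path}


    def save_checkpoint(callback: ToyCheckpointCallback, filepath: str) -> dict:
        # Stands in for the trainer: callback state is snapshotted at save time.
        return {"callbacks": callback.state_dict(), "filepath": filepath}


    cb = ToyCheckpointCallback()

    # Old order: save first, update the attribute after -- the snapshot is stale.
    ckpt = save_checkpoint(cb, "last.ckpt")
    cb.last_model_path = "last.ckpt"
    assert ckpt["callbacks"]["last_model_path"] == ""

    # Fixed order: set the path before saving so it becomes part of the saved
    # state, keeping the previous path around for cleanup afterwards.
    cb = ToyCheckpointCallback()
    previous, cb.last_model_path = cb.last_model_path, "last.ckpt"
    ckpt = save_checkpoint(cb, "last.ckpt")
    assert ckpt["callbacks"]["last_model_path"] == "last.ckpt"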
parent 3b699030d5
commit b8d2c65a37
@@ -451,6 +451,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


+- Fixed bug where the path for "last" checkpoints was not getting saved correctly which caused newer runs to not remove the previous "last" checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
+
+
+- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
+
+
 - Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
@@ -641,12 +641,11 @@ class ModelCheckpoint(Callback):
             return

         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
+        # set the last model path before saving because it will be part of the state.
+        previous, self.last_model_path = self.last_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.last_model_path and self.last_model_path != filepath:
-            trainer.strategy.remove_checkpoint(self.last_model_path)
-
-        self.last_model_path = filepath
+        if previous and previous != filepath:
+            trainer.strategy.remove_checkpoint(previous)

     def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is None or self.save_top_k == 0:
@@ -666,12 +665,11 @@ class ModelCheckpoint(Callback):
             return

         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
+        # set the best model path before saving because it will be part of the state.
+        previous, self.best_model_path = self.best_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
-            trainer.strategy.remove_checkpoint(self.best_model_path)
-
-        self.best_model_path = filepath
+        if self.save_top_k == 1 and previous and previous != filepath:
+            trainer.strategy.remove_checkpoint(previous)

     def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
         return self.monitor in metrics or len(metrics) == 0
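
In both methods the old path is stashed in `previous` before the save so cleanup can happen afterwards, and the `previous != filepath` check keeps a re-save under the same formatted name from deleting the file that was just written. A hypothetical helper, not part of the codebase, isolating that guard:

    def should_remove_previous(previous: str, filepath: str) -> bool:
        # Mirrors the guard in the diff above: only remove a previous checkpoint
        # that exists and does not collide with the file just written.
        return bool(previous) and previous != filepath


    assert should_remove_previous("epoch=0-last.ckpt", "epoch=1-last.ckpt")
    assert not should_remove_previous("epoch=1-last.ckpt", "epoch=1-last.ckpt")
    assert not should_remove_previous("", "epoch=1-last.ckpt")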
@@ -1234,3 +1234,30 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
             assert getattr(cb_restore, key) == val
         else:
             assert getattr(cb_restore, key) != val
+
+
+def test_save_last_saves_correct_last_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    mc.CHECKPOINT_NAME_LAST = "{foo}-last"
+    trainer = Trainer(callbacks=mc)
+    trainer.strategy.connect(BoringModel())
+
+    mc._save_last_checkpoint(trainer, {"foo": 1})
+    expected = "foo=1-last.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path
+
+
+def test_none_monitor_saves_correct_best_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor=None)
+    trainer = Trainer(callbacks=mc)
+    trainer.strategy.connect(BoringModel())
+
+    mc._save_none_monitor_checkpoint(trainer, {})
+    expected = "epoch=0-step=0.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path
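
These tests assert the user-visible effect: the checkpoint written to disk now records its own path under the callback's `state_key`, so a resumed run knows which file to replace. A self-contained round trip with toy state (the real callback entry holds more fields, and the real dictionary key is the callback's `state_key`, not the literal string used here):

    import os
    import tempfile

    import torch

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "last.ckpt")
        # Set the path before "saving", as the fix does, so the file records itself.
        state = {"last_model_path": path}
        torch.save({"callbacks": {"ModelCheckpoint": state}}, path)

        restored = torch.load(path)["callbacks"]["ModelCheckpoint"]
        assert restored["last_model_path"] == path  # the next run can remove/replace it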
@@ -40,6 +40,7 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
     with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
         trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path  # last `fit` replaced the `best_model_path`
     assert callback.state == 111
     assert trainer.checkpoint_connector.resume_checkpoint_path is None
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None