diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b6d7f61bc..07b9dbbe56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -451,6 +451,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155)) +- Fixed bug where the path for "last" checkpoints was not getting saved correctly which caused newer runs to not remove the previous "last" checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481)) + + +- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481)) + + - Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 442bd7a692..f857bbdeac 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -641,12 +641,11 @@ class ModelCheckpoint(Callback): return filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) + # set the last model path before saving because it will be part of the state. + previous, self.last_model_path = self.last_model_path, filepath trainer.save_checkpoint(filepath, self.save_weights_only) - - if self.last_model_path and self.last_model_path != filepath: - trainer.strategy.remove_checkpoint(self.last_model_path) - - self.last_model_path = filepath + if previous and previous != filepath: + trainer.strategy.remove_checkpoint(previous) def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: if self.monitor is None or self.save_top_k == 0: @@ -666,12 +665,11 @@ class ModelCheckpoint(Callback): return filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) + # set the best model path before saving because it will be part of the state. + previous, self.best_model_path = self.best_model_path, filepath trainer.save_checkpoint(filepath, self.save_weights_only) - - if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath: - trainer.strategy.remove_checkpoint(self.best_model_path) - - self.best_model_path = filepath + if self.save_top_k == 1 and previous and previous != filepath: + trainer.strategy.remove_checkpoint(previous) def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool: return self.monitor in metrics or len(metrics) == 0 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 315c33bb6c..94431030b5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1234,3 +1234,30 @@ def test_model_checkpoint_saveload_ckpt(tmpdir): assert getattr(cb_restore, key) == val else: assert getattr(cb_restore, key) != val + + +def test_save_last_saves_correct_last_model_path(tmpdir): + mc = ModelCheckpoint(dirpath=tmpdir, save_last=True) + mc.CHECKPOINT_NAME_LAST = "{foo}-last" + trainer = Trainer(callbacks=mc) + trainer.strategy.connect(BoringModel()) + + mc._save_last_checkpoint(trainer, {"foo": 1}) + expected = "foo=1-last.ckpt" + assert os.listdir(tmpdir) == [expected] + full_path = str(tmpdir / expected) + ckpt = torch.load(full_path) + assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path + + +def test_none_monitor_saves_correct_best_model_path(tmpdir): + mc = ModelCheckpoint(dirpath=tmpdir, monitor=None) + trainer = Trainer(callbacks=mc) + trainer.strategy.connect(BoringModel()) + + mc._save_none_monitor_checkpoint(trainer, {}) + expected = "epoch=0-step=0.ckpt" + assert os.listdir(tmpdir) == [expected] + full_path = str(tmpdir / expected) + ckpt = torch.load(full_path) + assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path diff --git a/tests/deprecated_api/test_remove_2-0.py b/tests/deprecated_api/test_remove_2-0.py index ed0520b9b8..8b7dbe7488 100644 --- a/tests/deprecated_api/test_remove_2-0.py +++ b/tests/deprecated_api/test_remove_2-0.py @@ -40,6 +40,7 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir): assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"): trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path # last `fit` replaced the `best_model_path` assert callback.state == 111 assert trainer.checkpoint_connector.resume_checkpoint_path is None assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None