diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index ebf7017e26..43ef28c7d8 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -402,6 +402,7 @@ class ModelCheckpoint(Callback):
         self._save_model(filepath, trainer, pl_module)
         if self.last_model_path and self.last_model_path != filepath:
             self._del_model(self.last_model_path)
+        self.last_model_path = filepath
 
     def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 39c733922a..f7db35fe04 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -150,12 +150,33 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
 
 
+def test_model_checkpoint_save_last(tmpdir):
+    """Tests that save_last produces only one last checkpoint."""
+    model = EvalModelTemplate()
+    epochs = 3
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
+    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=epochs,
+    )
+    trainer.fit(model)
+    last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
+    last_filename = last_filename + '.ckpt'
+    assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
+    assert set(os.listdir(tmpdir)) == set(
+        [f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs']
+    )
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
+
+
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     """Tests that the save_last checkpoint contains the latest information."""
     seed_everything(100)
     model = EvalModelTemplate()
     num_epochs = 3
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
     model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -164,30 +185,23 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
         max_epochs=num_epochs,
     )
     trainer.fit(model)
-    last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {})
-    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
-    path_last = str(tmpdir / f'{last_filename}.ckpt')  # last-epoch=3.ckpt
-    assert path_last_epoch != path_last
+
+    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
+    assert path_last_epoch != model_checkpoint.last_model_path
+
     ckpt_last_epoch = torch.load(path_last_epoch)
-    ckpt_last = torch.load(path_last)
-
-    trainer_keys = ("epoch", "global_step")
-    for key in trainer_keys:
-        assert ckpt_last_epoch[key] == ckpt_last[key]
-
-    checkpoint_callback_keys = ("best_model_score", "best_model_path")
-    for key in checkpoint_callback_keys:
-        assert (
-            ckpt_last["callbacks"][type(model_checkpoint)][key]
-            == ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
-        )
+    ckpt_last = torch.load(model_checkpoint.last_model_path)
+    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
+    assert all(
+        ckpt_last["callbacks"][type(model_checkpoint)][k] == ckpt_last_epoch["callbacks"][type(model_checkpoint)][k]
+        for k in ("best_model_score", "best_model_path")
+    )
 
     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
-    model_last = EvalModelTemplate.load_from_checkpoint(path_last)
+    model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
     for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
         assert w0.eq(w1).all()
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
 
 
 def test_ckpt_metric_names(tmpdir):