Add missing line. Add a test (#3594)

This commit is contained in:
Carlos Mocholí 2020-09-22 04:17:51 +02:00 committed by GitHub
parent 402b5fc346
commit 1223cdbaa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 19 deletions

View File

@@ -402,6 +402,7 @@ class ModelCheckpoint(Callback):
self._save_model(filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != filepath:
self._del_model(self.last_model_path)
self.last_model_path = filepath
def _is_valid_monitor_key(self, metrics):
return self.monitor in metrics or len(metrics) == 0

View File

@@ -150,12 +150,33 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
def test_model_checkpoint_save_last(tmpdir):
    """Tests that save_last produces only one last checkpoint."""
    model = EvalModelTemplate()
    epochs = 3
    # use a templated "last" name so the rendered filename is epoch-dependent
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=model_checkpoint,
        max_epochs=epochs,
    )
    trainer.fit(model)

    # the "last" name goes through the same formatting machinery as regular names
    rendered = model_checkpoint._format_checkpoint_name(
        ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}
    )
    last_filename = rendered + '.ckpt'
    assert model_checkpoint.last_model_path == str(tmpdir / last_filename)

    # one checkpoint per epoch, exactly one "last", plus the logger directory
    expected = {f'epoch={i}.ckpt' for i in range(epochs)}
    expected.add(last_filename)
    expected.add('lightning_logs')
    assert set(os.listdir(tmpdir)) == expected

    # restore the class-level default so other tests are unaffected
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Tests that the save_last checkpoint contains the latest information.

    NOTE(review): the scraped diff interleaved the pre- and post-commit
    bodies of this test (duplicate ``path_last_epoch`` assignments, two
    ``torch.load`` calls for ``ckpt_last``, and both the old for-loop and
    the new ``all(...)`` forms of the same assertions), and the hunk header
    swallowed two keyword arguments of the ``Trainer(...)`` call. This body
    reconstructs the coherent post-commit version, which compares the
    last-epoch checkpoint against ``model_checkpoint.last_model_path``
    instead of re-deriving the "last" filename by hand.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        # restored from diff context cut by the hunk header -- matches the
        # sibling test above; TODO confirm against the upstream commit
        early_stop_callback=False,
        checkpoint_callback=model_checkpoint,
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    # the "last" checkpoint must be a distinct file from the per-epoch one
    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
    assert path_last_epoch != model_checkpoint.last_model_path

    # trainer progress recorded in both checkpoints must agree
    ckpt_last_epoch = torch.load(path_last_epoch)
    ckpt_last = torch.load(model_checkpoint.last_model_path)
    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
    # the ModelCheckpoint callback state must also agree
    assert all(
        ckpt_last["callbacks"][type(model_checkpoint)][k] == ckpt_last_epoch["callbacks"][type(model_checkpoint)][k]
        for k in ("best_model_score", "best_model_path")
    )

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
    model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
        assert w0.eq(w1).all()

    # restore the class-level default so other tests are unaffected
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
def test_ckpt_metric_names(tmpdir):