Add missing line. Add a test (#3594)
This commit is contained in:
parent
402b5fc346
commit
1223cdbaa1
|
@ -402,6 +402,7 @@ class ModelCheckpoint(Callback):
|
|||
self._save_model(filepath, trainer, pl_module)
|
||||
if self.last_model_path and self.last_model_path != filepath:
|
||||
self._del_model(self.last_model_path)
|
||||
self.last_model_path = filepath
|
||||
|
||||
def _is_valid_monitor_key(self, metrics):
|
||||
return self.monitor in metrics or len(metrics) == 0
|
||||
|
|
|
@ -150,12 +150,33 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
|||
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
|
||||
|
||||
|
||||
def test_model_checkpoint_save_last(tmpdir):
|
||||
"""Tests that save_last produces only one last checkpoint."""
|
||||
model = EvalModelTemplate()
|
||||
epochs = 3
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
|
||||
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=False,
|
||||
checkpoint_callback=model_checkpoint,
|
||||
max_epochs=epochs,
|
||||
)
|
||||
trainer.fit(model)
|
||||
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
|
||||
last_filename = last_filename + '.ckpt'
|
||||
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
|
||||
assert set(os.listdir(tmpdir)) == set(
|
||||
[f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs']
|
||||
)
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
|
||||
|
||||
|
||||
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
||||
"""Tests that the save_last checkpoint contains the latest information."""
|
||||
seed_everything(100)
|
||||
model = EvalModelTemplate()
|
||||
num_epochs = 3
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
|
||||
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -164,30 +185,23 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
|||
max_epochs=num_epochs,
|
||||
)
|
||||
trainer.fit(model)
|
||||
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {})
|
||||
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) # epoch=3.ckpt
|
||||
path_last = str(tmpdir / f'{last_filename}.ckpt') # last-epoch=3.ckpt
|
||||
assert path_last_epoch != path_last
|
||||
|
||||
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
|
||||
assert path_last_epoch != model_checkpoint.last_model_path
|
||||
|
||||
ckpt_last_epoch = torch.load(path_last_epoch)
|
||||
ckpt_last = torch.load(path_last)
|
||||
|
||||
trainer_keys = ("epoch", "global_step")
|
||||
for key in trainer_keys:
|
||||
assert ckpt_last_epoch[key] == ckpt_last[key]
|
||||
|
||||
checkpoint_callback_keys = ("best_model_score", "best_model_path")
|
||||
for key in checkpoint_callback_keys:
|
||||
assert (
|
||||
ckpt_last["callbacks"][type(model_checkpoint)][key]
|
||||
== ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
|
||||
)
|
||||
ckpt_last = torch.load(model_checkpoint.last_model_path)
|
||||
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
|
||||
assert all(
|
||||
ckpt_last["callbacks"][type(model_checkpoint)][k] == ckpt_last_epoch["callbacks"][type(model_checkpoint)][k]
|
||||
for k in ("best_model_score", "best_model_path")
|
||||
)
|
||||
|
||||
# it is easier to load the model objects than to iterate over the raw dict of tensors
|
||||
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
|
||||
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
|
||||
model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
|
||||
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
|
||||
assert w0.eq(w1).all()
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
|
||||
|
||||
|
||||
def test_ckpt_metric_names(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue