Add missing line. Add a test (#3594)
This commit is contained in:
parent
402b5fc346
commit
1223cdbaa1
|
@ -402,6 +402,7 @@ class ModelCheckpoint(Callback):
|
||||||
self._save_model(filepath, trainer, pl_module)
|
self._save_model(filepath, trainer, pl_module)
|
||||||
if self.last_model_path and self.last_model_path != filepath:
|
if self.last_model_path and self.last_model_path != filepath:
|
||||||
self._del_model(self.last_model_path)
|
self._del_model(self.last_model_path)
|
||||||
|
self.last_model_path = filepath
|
||||||
|
|
||||||
def _is_valid_monitor_key(self, metrics):
|
def _is_valid_monitor_key(self, metrics):
|
||||||
return self.monitor in metrics or len(metrics) == 0
|
return self.monitor in metrics or len(metrics) == 0
|
||||||
|
|
|
@ -150,12 +150,33 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
||||||
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
|
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_checkpoint_save_last(tmpdir):
|
||||||
|
"""Tests that save_last produces only one last checkpoint."""
|
||||||
|
model = EvalModelTemplate()
|
||||||
|
epochs = 3
|
||||||
|
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
|
||||||
|
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
early_stop_callback=False,
|
||||||
|
checkpoint_callback=model_checkpoint,
|
||||||
|
max_epochs=epochs,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
|
||||||
|
last_filename = last_filename + '.ckpt'
|
||||||
|
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
|
||||||
|
assert set(os.listdir(tmpdir)) == set(
|
||||||
|
[f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs']
|
||||||
|
)
|
||||||
|
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
|
||||||
|
|
||||||
|
|
||||||
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
||||||
"""Tests that the save_last checkpoint contains the latest information."""
|
"""Tests that the save_last checkpoint contains the latest information."""
|
||||||
seed_everything(100)
|
seed_everything(100)
|
||||||
model = EvalModelTemplate()
|
model = EvalModelTemplate()
|
||||||
num_epochs = 3
|
num_epochs = 3
|
||||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
|
|
||||||
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
|
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
default_root_dir=tmpdir,
|
default_root_dir=tmpdir,
|
||||||
|
@ -164,30 +185,23 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
|
||||||
max_epochs=num_epochs,
|
max_epochs=num_epochs,
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {})
|
|
||||||
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) # epoch=3.ckpt
|
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
|
||||||
path_last = str(tmpdir / f'{last_filename}.ckpt') # last-epoch=3.ckpt
|
assert path_last_epoch != model_checkpoint.last_model_path
|
||||||
assert path_last_epoch != path_last
|
|
||||||
ckpt_last_epoch = torch.load(path_last_epoch)
|
ckpt_last_epoch = torch.load(path_last_epoch)
|
||||||
ckpt_last = torch.load(path_last)
|
ckpt_last = torch.load(model_checkpoint.last_model_path)
|
||||||
|
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
|
||||||
trainer_keys = ("epoch", "global_step")
|
assert all(
|
||||||
for key in trainer_keys:
|
ckpt_last["callbacks"][type(model_checkpoint)][k] == ckpt_last_epoch["callbacks"][type(model_checkpoint)][k]
|
||||||
assert ckpt_last_epoch[key] == ckpt_last[key]
|
for k in ("best_model_score", "best_model_path")
|
||||||
|
)
|
||||||
checkpoint_callback_keys = ("best_model_score", "best_model_path")
|
|
||||||
for key in checkpoint_callback_keys:
|
|
||||||
assert (
|
|
||||||
ckpt_last["callbacks"][type(model_checkpoint)][key]
|
|
||||||
== ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
|
|
||||||
)
|
|
||||||
|
|
||||||
# it is easier to load the model objects than to iterate over the raw dict of tensors
|
# it is easier to load the model objects than to iterate over the raw dict of tensors
|
||||||
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
|
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
|
||||||
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
|
model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
|
||||||
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
|
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
|
||||||
assert w0.eq(w1).all()
|
assert w0.eq(w1).all()
|
||||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
|
|
||||||
|
|
||||||
|
|
||||||
def test_ckpt_metric_names(tmpdir):
|
def test_ckpt_metric_names(tmpdir):
|
||||||
|
|
Loading…
Reference in New Issue