Make `ModelCheckpoint._format_checkpoint_name` an instance method (#19054)

This commit is contained in:
Adrian Wälchli 2023-11-23 01:05:48 +01:00 committed by GitHub
parent dbea69be61
commit 9a26da8081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 16 deletions

View File

@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))
- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))
## [2.1.2] - 2023-11-15

View File

@ -524,9 +524,8 @@ class ModelCheckpoint(Checkpoint):
return should_update_best_and_save
@classmethod
def _format_checkpoint_name(
cls,
self,
filename: Optional[str],
metrics: Dict[str, Tensor],
prefix: str = "",
@ -534,7 +533,7 @@ class ModelCheckpoint(Checkpoint):
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"
filename = "{epoch}" + self.CHECKPOINT_JOIN_CHAR + "{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
@ -547,7 +546,7 @@ class ModelCheckpoint(Checkpoint):
name = group[1:]
if auto_insert_metric_name:
filename = filename.replace(group, name + cls.CHECKPOINT_EQUALS_CHAR + "{" + name)
filename = filename.replace(group, name + self.CHECKPOINT_EQUALS_CHAR + "{" + name)
# support for dots: https://stackoverflow.com/a/7934969
filename = filename.replace(group, f"{{0[{name}]")
@ -557,7 +556,7 @@ class ModelCheckpoint(Checkpoint):
filename = filename.format(metrics)
if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
return filename

View File

@ -402,34 +402,36 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
model_checkpoint = ModelCheckpoint(dirpath=tmpdir)
# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
ckpt_name = model_checkpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
assert ckpt_name == "epoch=3-step=2"
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
ckpt_name = model_checkpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
assert ckpt_name == "test-epoch=3-step=2"
# no groups case:
ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test")
ckpt_name = model_checkpoint._format_checkpoint_name("ckpt", {}, prefix="test")
assert ckpt_name == "test-ckpt"
# no prefix
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
assert ckpt_name == "epoch=003-acc=0.03"
# one metric name is substring of another
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
assert ckpt_name == "epoch=003-epoch_test=003"
# prefix
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
model_checkpoint.CHECKPOINT_JOIN_CHAR = "@"
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
assert ckpt_name == "test@epoch=3,acc=0.03000"
monkeypatch.undo()
# non-default char for equals sign
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
model_checkpoint.CHECKPOINT_EQUALS_CHAR = ":"
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
assert ckpt_name == "epoch:003-acc:0.03"
monkeypatch.undo()
@ -454,13 +456,13 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt"
# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
ckpt_name = model_checkpoint._format_checkpoint_name(
"epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False
)
assert ckpt_name == "epoch=003-val_acc=0.03"
# dots in the metric name
ckpt_name = ModelCheckpoint._format_checkpoint_name(
ckpt_name = model_checkpoint._format_checkpoint_name(
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
)
assert ckpt_name == "mAP@0.50=0.2000"