Make `ModelCheckpoint._format_checkpoint_name` an instance method (#19054)
This commit is contained in:
parent
dbea69be61
commit
9a26da8081
|
@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))
|
||||
|
||||
|
||||
- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))
|
||||
|
||||
|
||||
## [2.1.2] - 2023-11-15
|
||||
|
||||
|
|
|
@ -524,9 +524,8 @@ class ModelCheckpoint(Checkpoint):
|
|||
|
||||
return should_update_best_and_save
|
||||
|
||||
@classmethod
|
||||
def _format_checkpoint_name(
|
||||
cls,
|
||||
self,
|
||||
filename: Optional[str],
|
||||
metrics: Dict[str, Tensor],
|
||||
prefix: str = "",
|
||||
|
@ -534,7 +533,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
) -> str:
|
||||
if not filename:
|
||||
# filename is not set, use default name
|
||||
filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"
|
||||
filename = "{epoch}" + self.CHECKPOINT_JOIN_CHAR + "{step}"
|
||||
|
||||
# check and parse user passed keys in the string
|
||||
groups = re.findall(r"(\{.*?)[:\}]", filename)
|
||||
|
@ -547,7 +546,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
name = group[1:]
|
||||
|
||||
if auto_insert_metric_name:
|
||||
filename = filename.replace(group, name + cls.CHECKPOINT_EQUALS_CHAR + "{" + name)
|
||||
filename = filename.replace(group, name + self.CHECKPOINT_EQUALS_CHAR + "{" + name)
|
||||
|
||||
# support for dots: https://stackoverflow.com/a/7934969
|
||||
filename = filename.replace(group, f"{{0[{name}]")
|
||||
|
@ -557,7 +556,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
filename = filename.format(metrics)
|
||||
|
||||
if prefix:
|
||||
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
|
||||
filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
|
||||
|
||||
return filename
|
||||
|
||||
|
|
|
@ -402,34 +402,36 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
|
|||
|
||||
|
||||
def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
|
||||
model_checkpoint = ModelCheckpoint(dirpath=tmpdir)
|
||||
|
||||
# empty filename:
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
|
||||
assert ckpt_name == "epoch=3-step=2"
|
||||
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
|
||||
assert ckpt_name == "test-epoch=3-step=2"
|
||||
|
||||
# no groups case:
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test")
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("ckpt", {}, prefix="test")
|
||||
assert ckpt_name == "test-ckpt"
|
||||
|
||||
# no prefix
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
|
||||
assert ckpt_name == "epoch=003-acc=0.03"
|
||||
|
||||
# one metric name is substring of another
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
|
||||
assert ckpt_name == "epoch=003-epoch_test=003"
|
||||
|
||||
# prefix
|
||||
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@")
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
|
||||
model_checkpoint.CHECKPOINT_JOIN_CHAR = "@"
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
|
||||
assert ckpt_name == "test@epoch=3,acc=0.03000"
|
||||
monkeypatch.undo()
|
||||
|
||||
# non-default char for equals sign
|
||||
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":")
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
|
||||
model_checkpoint.CHECKPOINT_EQUALS_CHAR = ":"
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
|
||||
assert ckpt_name == "epoch:003-acc:0.03"
|
||||
monkeypatch.undo()
|
||||
|
||||
|
@ -454,13 +456,13 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
|
|||
assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt"
|
||||
|
||||
# auto_insert_metric_name=False
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name(
|
||||
"epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False
|
||||
)
|
||||
assert ckpt_name == "epoch=003-val_acc=0.03"
|
||||
|
||||
# dots in the metric name
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(
|
||||
ckpt_name = model_checkpoint._format_checkpoint_name(
|
||||
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
|
||||
)
|
||||
assert ckpt_name == "mAP@0.50=0.2000"
|
||||
|
|
Loading…
Reference in New Issue