diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9e6411cd7b..2a88d29f5b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072)) +- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054)) + ## [2.1.2] - 2023-11-15 diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index f9c323f642..565aefaf3a 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -524,9 +524,8 @@ class ModelCheckpoint(Checkpoint): return should_update_best_and_save - @classmethod def _format_checkpoint_name( - cls, + self, filename: Optional[str], metrics: Dict[str, Tensor], prefix: str = "", @@ -534,7 +533,7 @@ class ModelCheckpoint(Checkpoint): ) -> str: if not filename: # filename is not set, use default name - filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}" + filename = "{epoch}" + self.CHECKPOINT_JOIN_CHAR + "{step}" # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) @@ -547,7 +546,7 @@ class ModelCheckpoint(Checkpoint): name = group[1:] if auto_insert_metric_name: - filename = filename.replace(group, name + cls.CHECKPOINT_EQUALS_CHAR + "{" + name) + filename = filename.replace(group, name + self.CHECKPOINT_EQUALS_CHAR + "{" + name) # support for dots: https://stackoverflow.com/a/7934969 filename = filename.replace(group, f"{{0[{name}]") @@ -557,7 +556,7 @@ class ModelCheckpoint(Checkpoint): filename = filename.format(metrics) if prefix: - filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) + filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) return filename diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index a7eb8b544a..81cc98cf50 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -402,34 +402,36 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch): + model_checkpoint = ModelCheckpoint(dirpath=tmpdir) + # empty filename: - ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2}) + ckpt_name = model_checkpoint._format_checkpoint_name("", {"epoch": 3, "step": 2}) assert ckpt_name == "epoch=3-step=2" - ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test") + ckpt_name = model_checkpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test") assert ckpt_name == "test-epoch=3-step=2" # no groups case: - ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test") + ckpt_name = model_checkpoint._format_checkpoint_name("ckpt", {}, prefix="test") assert ckpt_name == "test-ckpt" # no prefix - ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03}) + ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03}) assert ckpt_name == "epoch=003-acc=0.03" # one metric name is substring of another - ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3}) + ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3}) assert ckpt_name == "epoch=003-epoch_test=003" # prefix - monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@") - ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test") + model_checkpoint.CHECKPOINT_JOIN_CHAR = "@" + ckpt_name = model_checkpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test") assert ckpt_name == "test@epoch=3,acc=0.03000" monkeypatch.undo() # non-default char for equals sign - monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":") - ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03}) + model_checkpoint.CHECKPOINT_EQUALS_CHAR = ":" + ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03}) assert ckpt_name == "epoch:003-acc:0.03" monkeypatch.undo() @@ -454,13 +456,13 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch): assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt" # auto_insert_metric_name=False - ckpt_name = ModelCheckpoint._format_checkpoint_name( + ckpt_name = model_checkpoint._format_checkpoint_name( "epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False ) assert ckpt_name == "epoch=003-val_acc=0.03" # dots in the metric name - ckpt_name = ModelCheckpoint._format_checkpoint_name( + ckpt_name = model_checkpoint._format_checkpoint_name( "mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False ) assert ckpt_name == "mAP@0.50=0.2000"