diff --git a/CHANGELOG.md b/CHANGELOG.md index d21359729d..f6d870f284 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,7 +135,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716)) -- +- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783)) ## [1.6.1] - 2022-04-13 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ea23bd21f6..735bef6a63 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -517,9 +517,12 @@ class ModelCheckpoint(Callback): if auto_insert_metric_name: filename = filename.replace(group, name + "={" + name) + # support for dots: https://stackoverflow.com/a/7934969 + filename = filename.replace(group, f"{{0[{name}]") + if name not in metrics: metrics[name] = 0 - filename = filename.format(**metrics) + filename = filename.format(metrics) if prefix: filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 9e112f81ad..48e6c9b294 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -452,6 +452,12 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): ) assert ckpt_name == "epoch=003-val_acc=0.03" + # dots in the metric name + ckpt_name = ModelCheckpoint._format_checkpoint_name( + "mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False + ) + assert ckpt_name == "mAP@0.50=0.2000" + class ModelCheckpointExtensionTest(ModelCheckpoint): FILE_EXTENSION = ".tpkc"