From b155a6323f1187abbb6535d2bdf69aadc210432f Mon Sep 17 00:00:00 2001 From: Henry Lau <70014887+HenryLau0220@users.noreply.github.com> Date: Fri, 22 Apr 2022 04:59:32 +0800 Subject: [PATCH] Fix support for `ModelCheckpoint` monitors with dots (#12783) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 5 ++++- tests/checkpointing/test_model_checkpoint.py | 6 ++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d21359729d..f6d870f284 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,7 +135,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716)) -- +- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783)) ## [1.6.1] - 2022-04-13 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ea23bd21f6..735bef6a63 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -517,9 +517,12 @@ class ModelCheckpoint(Callback): if auto_insert_metric_name: filename = filename.replace(group, name + "={" + name) + # support for dots: https://stackoverflow.com/a/7934969 + filename = filename.replace(group, f"{{0[{name}]") + if name not in metrics: metrics[name] = 0 - filename = filename.format(**metrics) + filename = filename.format(metrics) if prefix: filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 9e112f81ad..48e6c9b294 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -452,6 +452,12 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): ) assert ckpt_name == "epoch=003-val_acc=0.03" + # dots in the metric name + ckpt_name = ModelCheckpoint._format_checkpoint_name( + "mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False + ) + assert ckpt_name == "mAP@0.50=0.2000" + class ModelCheckpointExtensionTest(ModelCheckpoint): FILE_EXTENSION = ".tpkc"