Fix support for `ModelCheckpoint` monitors with dots (#12783)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Henry Lau 2022-04-22 04:59:32 +08:00 committed by GitHub
parent 54a2b5ceeb
commit b155a6323f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 2 deletions

View File

@ -135,7 +135,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))
-
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))
## [1.6.1] - 2022-04-13

View File

@ -517,9 +517,12 @@ class ModelCheckpoint(Callback):
if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)
# support for dots: https://stackoverflow.com/a/7934969
filename = filename.replace(group, f"{{0[{name}]")
if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
filename = filename.format(metrics)
if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

View File

@ -452,6 +452,12 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
)
assert ckpt_name == "epoch=003-val_acc=0.03"
# dots in the metric name
ckpt_name = ModelCheckpoint._format_checkpoint_name(
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
)
assert ckpt_name == "mAP@0.50=0.2000"
class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = ".tpkc"