Fix support for `ModelCheckpoint` monitors with dots (#12783)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
54a2b5ceeb
commit
b155a6323f
|
@ -135,7 +135,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))
|
||||
|
||||
|
||||
-
|
||||
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))
|
||||
|
||||
|
||||
## [1.6.1] - 2022-04-13
|
||||
|
|
|
@ -517,9 +517,12 @@ class ModelCheckpoint(Callback):
|
|||
if auto_insert_metric_name:
|
||||
filename = filename.replace(group, name + "={" + name)
|
||||
|
||||
# support for dots: https://stackoverflow.com/a/7934969
|
||||
filename = filename.replace(group, f"{{0[{name}]")
|
||||
|
||||
if name not in metrics:
|
||||
metrics[name] = 0
|
||||
filename = filename.format(**metrics)
|
||||
filename = filename.format(metrics)
|
||||
|
||||
if prefix:
|
||||
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
|
||||
|
|
|
@ -452,6 +452,12 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
|
|||
)
|
||||
assert ckpt_name == "epoch=003-val_acc=0.03"
|
||||
|
||||
# dots in the metric name
|
||||
ckpt_name = ModelCheckpoint._format_checkpoint_name(
|
||||
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
|
||||
)
|
||||
assert ckpt_name == "mAP@0.50=0.2000"
|
||||
|
||||
|
||||
class ModelCheckpointExtensionTest(ModelCheckpoint):
|
||||
FILE_EXTENSION = ".tpkc"
|
||||
|
|
Loading…
Reference in New Issue