fix logged keys in mlflow logger (#4412)
* [#4411] fix gpu_log_memory with mlflow logger * sanitize parenthesis instead of removing for all loggers * apply regex for mlflow key sanitization * replace ',' with '.' typo * add single warning and test Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
11415faade
commit
470e2945fc
|
@ -16,6 +16,8 @@
|
|||
MLflow
|
||||
------
|
||||
"""
|
||||
import re
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from time import time
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
@ -151,6 +153,13 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
if isinstance(v, str):
|
||||
log.warning(f'Discarding metric with string value {k}={v}.')
|
||||
continue
|
||||
|
||||
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
|
||||
if k != new_k:
|
||||
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
|
||||
f"Replacing {k} with {new_k}."))
|
||||
k = new_k
|
||||
|
||||
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
|
||||
|
||||
@rank_zero_only
|
||||
|
|
|
@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir):
|
|||
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
|
||||
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
|
||||
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3,
|
||||
log_gpu_memory=True)
|
||||
trainer.fit(model)
|
||||
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
|
||||
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
|
||||
|
|
Loading…
Reference in New Issue