fix logged keys in mlflow logger (#4412)

* [#4411] fix gpu_log_memory with mlflow logger

* sanitize parenthesis instead of removing for all loggers

* apply regex for mlflow key sanitization

* replace ',' with '.' typo

* add single warning and test

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
Diedre Carmo 2020-11-10 08:50:25 -03:00 committed by GitHub
parent 11415faade
commit 470e2945fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -16,6 +16,8 @@
MLflow
------
"""
import re
import warnings
from argparse import Namespace
from time import time
from typing import Any, Dict, Optional, Union
@ -151,6 +153,13 @@ class MLFlowLogger(LightningLoggerBase):
if isinstance(v, str):
log.warning(f'Discarding metric with string value {k}={v}.')
continue
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
if k != new_k:
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
f"Replacing {k} with {new_k}."))
k = new_k
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
@rank_zero_only

View File

@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3,
log_gpu_memory=True)
trainer.fit(model)
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')