From 470e2945fc0fefa1b6d42f1d64fef9f22f5c75b3 Mon Sep 17 00:00:00 2001 From: Diedre Carmo Date: Tue, 10 Nov 2020 08:50:25 -0300 Subject: [PATCH] fix logged keys in mlflow logger (#4412) * [#4411] fix gpu_log_memory with mlflow logger * sanitize parenthesis instead of removing for all loggers * apply regex for mlflow key sanitization * replace ',' with '.' typo * add single warning and test Co-authored-by: Rohit Gupta Co-authored-by: chaton --- pytorch_lightning/loggers/mlflow.py | 9 +++++++++ tests/loggers/test_mlflow.py | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index de915785dc..ee9f8f86cf 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -16,6 +16,8 @@ MLflow ------ """ +import re +import warnings from argparse import Namespace from time import time from typing import Any, Dict, Optional, Union @@ -151,6 +153,13 @@ class MLFlowLogger(LightningLoggerBase): if isinstance(v, str): log.warning(f'Discarding metric with string value {k}={v}.') continue + + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) + if k != new_k: + warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n", + f"Replacing {k} with {new_k}.")) + k = new_k + self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) @rank_zero_only diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index a200fbf549..b220074d41 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} model = EvalModelTemplate() - trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, + log_gpu_memory=True) trainer.fit(model) assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')