From 70361ebb6dd2b387e8c9350266467b6ade042fa5 Mon Sep 17 00:00:00 2001 From: Peter Gagarinov Date: Tue, 24 Nov 2020 08:50:34 +0300 Subject: [PATCH] Fixed a crash bug in MLFlow logger (#4716) * warnings.warn doesn't accept tuples, which causes "TypeError: expected string or bytes-like object" when the execution flow gets to this warning. Fixed that. * Try adding a mock test * Try adding a mock test Co-authored-by: rohitgr7 Co-authored-by: chaton --- pytorch_lightning/loggers/mlflow.py | 11 ++++++----- tests/loggers/test_mlflow.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 1a908899a7..92f1c15d58 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -17,7 +17,6 @@ MLflow Logger ------------- """ import re -import warnings from argparse import Namespace from time import time from typing import Any, Dict, Optional, Union @@ -32,7 +31,7 @@ except ModuleNotFoundError: # pragma: no-cover from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn LOCAL_FILE_URI_PREFIX = "file:" @@ -165,9 +164,11 @@ class MLFlowLogger(LightningLoggerBase): new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) if k != new_k: - warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n", - f"Replacing {k} with {new_k}.")) - k = new_k + rank_zero_warn( + "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name." + f" Replacing {k} with {new_k}.", RuntimeWarning + ) + k = new_k self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index b220074d41..c52dd82889 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -150,8 +150,24 @@ def test_mlflow_logger_dirs_creation(tmpdir): @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): + """ + Test that the logger experiment_id retrieved only once. + """ logger = MLFlowLogger('test', save_dir=tmpdir) _ = logger.experiment _ = logger.experiment _ = logger.experiment assert logger.experiment.get_experiment_by_name.call_count == 1 + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir): + """ + Test that the logger raises warning with special characters not accepted by MLFlow. + """ + logger = MLFlowLogger('test', save_dir=tmpdir) + metrics = {'[some_metric]': 10} + + with pytest.warns(RuntimeWarning, match='special characters in metric name'): + logger.log_metrics(metrics)