Fixed a crash bug in MLFlow logger (#4716)

* warnings.warn doesn't accept tuples, which causes "TypeError: expected string or bytes-like object" when the execution flow gets to this warning. Fixed that.

* Try adding a mock test

* Try adding a mock test

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
Peter Gagarinov 2020-11-24 08:50:34 +03:00 committed by GitHub
parent 471ca375ba
commit 70361ebb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 5 deletions

View File

@ -17,7 +17,6 @@ MLflow Logger
-------------
"""
import re
import warnings
from argparse import Namespace
from time import time
from typing import Any, Dict, Optional, Union
@ -32,7 +31,7 @@ except ModuleNotFoundError: # pragma: no-cover
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
LOCAL_FILE_URI_PREFIX = "file:"
@ -165,9 +164,11 @@ class MLFlowLogger(LightningLoggerBase):
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
if k != new_k:
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
f"Replacing {k} with {new_k}."))
k = new_k
rank_zero_warn(
"MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
f" Replacing {k} with {new_k}.", RuntimeWarning
)
k = new_k
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

View File

@ -150,8 +150,24 @@ def test_mlflow_logger_dirs_creation(tmpdir):
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
"""
Test that the logger experiment_id retrieved only once.
"""
logger = MLFlowLogger('test', save_dir=tmpdir)
_ = logger.experiment
_ = logger.experiment
_ = logger.experiment
assert logger.experiment.get_experiment_by_name.call_count == 1
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
"""
Test that the logger raises warning with special characters not accepted by MLFlow.
"""
logger = MLFlowLogger('test', save_dir=tmpdir)
metrics = {'[some_metric]': 10}
with pytest.warns(RuntimeWarning, match='special characters in metric name'):
logger.log_metrics(metrics)