Fixed a crash bug in MLFlow logger (#4716)
* warnings.warn doesn't accept tuples, which causes "TypeError: expected string or bytes-like object" when the execution flow gets to this warning. Fixed that. * Try adding a mock test * Try adding a mock test Co-authored-by: rohitgr7 <rohitgr1998@gmail.com> Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
471ca375ba
commit
70361ebb6d
|
@ -17,7 +17,6 @@ MLflow Logger
|
|||
-------------
|
||||
"""
|
||||
import re
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from time import time
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
@ -32,7 +31,7 @@ except ModuleNotFoundError: # pragma: no-cover
|
|||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
|
||||
|
||||
LOCAL_FILE_URI_PREFIX = "file:"
|
||||
|
||||
|
@ -165,9 +164,11 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
|
||||
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
|
||||
if k != new_k:
|
||||
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
|
||||
f"Replacing {k} with {new_k}."))
|
||||
k = new_k
|
||||
rank_zero_warn(
|
||||
"MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
|
||||
f" Replacing {k} with {new_k}.", RuntimeWarning
|
||||
)
|
||||
k = new_k
|
||||
|
||||
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
|
||||
|
||||
|
|
|
@ -150,8 +150,24 @@ def test_mlflow_logger_dirs_creation(tmpdir):
|
|||
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
|
||||
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
|
||||
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
|
||||
"""
|
||||
Test that the logger experiment_id retrieved only once.
|
||||
"""
|
||||
logger = MLFlowLogger('test', save_dir=tmpdir)
|
||||
_ = logger.experiment
|
||||
_ = logger.experiment
|
||||
_ = logger.experiment
|
||||
assert logger.experiment.get_experiment_by_name.call_count == 1
|
||||
|
||||
|
||||
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
|
||||
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
|
||||
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
|
||||
"""
|
||||
Test that the logger raises warning with special characters not accepted by MLFlow.
|
||||
"""
|
||||
logger = MLFlowLogger('test', save_dir=tmpdir)
|
||||
metrics = {'[some_metric]': 10}
|
||||
|
||||
with pytest.warns(RuntimeWarning, match='special characters in metric name'):
|
||||
logger.log_metrics(metrics)
|
||||
|
|
Loading…
Reference in New Issue