Two fixes for handling edge cases in MLflow logging (#16451)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Seppo Enarvi 2023-01-23 15:29:58 +02:00 committed by GitHub
parent 3611fcd152
commit 9346151359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 23 deletions

View File

@ -134,6 +134,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))
- Fixed logging more than 100 parameters with `MLFlowLogger` and long values are truncated ([#16451](https://github.com/Lightning-AI/lightning/pull/16451))
## [1.9.0] - 2023-01-17

View File

@ -238,18 +238,14 @@ class MLFlowLogger(Logger):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _flatten_dict(params)
params_list: List[Param] = []
for k, v in params.items():
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
if len(str(v)) > 250:
rank_zero_warn(
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
)
continue
params_list.append(Param(key=k, value=v))
# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
self.experiment.log_batch(run_id=self.run_id, params=params_list)
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:

View File

@ -224,19 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
logger.log_metrics(metrics)
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
value = "test" * 100
key = "test_param"
params = {key: value}
with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"):
logger.log_hyperparams(params)
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow.time")
@ -270,6 +257,38 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
)
def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250
@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length)
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
"""Test that long parameter values are truncated to 250 characters."""
logger = MLFlowLogger("test", save_dir=tmpdir)
params = {"test": "test_param" * 50}
logger.log_hyperparams(params)
# assert_called_once_with() won't properly check the parameter value.
logger.experiment.log_batch.assert_called_once()
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100
parameters."""
logger = MLFlowLogger("test", save_dir=tmpdir)
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
logger.log_hyperparams(params)
assert logger.experiment.log_batch.call_count == 2
@pytest.mark.parametrize(
"status,expected",
[