From 93461513596485af2af14c60a438d48669a0a2e0 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 23 Jan 2023 15:29:58 +0200 Subject: [PATCH] Two fixes for handling edge cases in MLflow logging (#16451) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/pytorch_lightning/CHANGELOG.md | 2 + src/pytorch_lightning/loggers/mlflow.py | 16 +++----- tests/tests_pytorch/loggers/test_mlflow.py | 45 +++++++++++++++------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index a4d5a7d12f..e8e6431aa3 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -134,6 +134,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418)) +- Fixed logging more than 100 parameters with `MLFlowLogger`; long parameter values are now truncated instead of discarded ([#16451](https://github.com/Lightning-AI/lightning/pull/16451)) + ## [1.9.0] - 2023-01-17 diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 980d4e4bcc..4b1088a6f4 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -238,18 +238,14 @@ class MLFlowLogger(Logger): def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) - params_list: List[Param] = [] - for k, v in params.items(): - # TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - if len(str(v)) > 250: - rank_zero_warn( - f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning - ) - continue - params_list.append(Param(key=k, value=v)) + # Truncate parameter values to 250 characters. 
+ # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 + params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] - self.experiment.log_batch(run_id=self.run_id, params=params_list) + # Log in chunks of 100 parameters (the maximum allowed by MLflow). + for idx in range(0, len(params_list), 100): + self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 23de563270..14879ed104 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -224,19 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): logger.log_metrics(metrics) -@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) -@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_long_param_value(client, _, tmpdir): - """Test that the logger raises warning with special characters not accepted by MLFlow.""" - logger = MLFlowLogger("test", save_dir=tmpdir) - value = "test" * 100 - key = "test_param" - params = {key: value} - - with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"): - logger.log_hyperparams(params) - - @mock.patch("pytorch_lightning.loggers.mlflow.Metric") @mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow.time") @@ -270,6 +257,38 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): ) +def _check_value_length(value, *args, **kwargs): + assert len(value) <= 250 + + +@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length) +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) 
+@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): + """Test that long parameter values are truncated to 250 characters.""" + logger = MLFlowLogger("test", save_dir=tmpdir) + + params = {"test": "test_param" * 50} + logger.log_hyperparams(params) + + # assert_called_once_with() won't properly check the parameter value. + logger.experiment.log_batch.assert_called_once() + + +@mock.patch("pytorch_lightning.loggers.mlflow.Param") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_with_many_params(client, _, param, tmpdir): + """Test that when logging more than 100 parameters, they are split into batches of at most 100 + parameters.""" + logger = MLFlowLogger("test", save_dir=tmpdir) + + params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)} + logger.log_hyperparams(params) + + assert logger.experiment.log_batch.call_count == 2 + + @pytest.mark.parametrize( "status,expected", [