Two fixes for handling edge cases in MLflow logging (#16451)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3611fcd152
commit
9346151359
|
@ -134,6 +134,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))
|
||||
|
||||
- Fixed logging more than 100 parameters with `MLFlowLogger` and long values are truncated ([#16451](https://github.com/Lightning-AI/lightning/pull/16451))
|
||||
|
||||
|
||||
|
||||
## [1.9.0] - 2023-01-17
|
||||
|
|
|
@ -238,18 +238,14 @@ class MLFlowLogger(Logger):
|
|||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
params_list: List[Param] = []
|
||||
|
||||
for k, v in params.items():
|
||||
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
|
||||
if len(str(v)) > 250:
|
||||
rank_zero_warn(
|
||||
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
|
||||
)
|
||||
continue
|
||||
params_list.append(Param(key=k, value=v))
|
||||
# Truncate parameter values to 250 characters.
|
||||
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
|
||||
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
|
||||
|
||||
self.experiment.log_batch(run_id=self.run_id, params=params_list)
|
||||
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
|
||||
for idx in range(0, len(params_list), 100):
|
||||
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
|
||||
|
|
|
@ -224,19 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
|
|||
logger.log_metrics(metrics)
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
|
||||
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
value = "test" * 100
|
||||
key = "test_param"
|
||||
params = {key: value}
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"):
|
||||
logger.log_hyperparams(params)
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.time")
|
||||
|
@ -270,6 +257,38 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
|
|||
)
|
||||
|
||||
|
||||
def _check_value_length(value, *args, **kwargs):
|
||||
assert len(value) <= 250
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length)
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
|
||||
"""Test that long parameter values are truncated to 250 characters."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
|
||||
params = {"test": "test_param" * 50}
|
||||
logger.log_hyperparams(params)
|
||||
|
||||
# assert_called_once_with() won't properly check the parameter value.
|
||||
logger.experiment.log_batch.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
|
||||
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100
|
||||
parameters."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
|
||||
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
|
||||
logger.log_hyperparams(params)
|
||||
|
||||
assert logger.experiment.log_batch.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status,expected",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue