Solved minor bug with MLFlow logger (#16418)

Resolves https://github.com/Lightning-AI/lightning/issues/16411
This commit is contained in:
Peutlefaire 2023-01-20 01:15:32 +01:00 committed by GitHub
parent d3de5c64d7
commit 6fd914f40b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 2 deletions

View File

@ -132,6 +132,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369)) - Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))
- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))
## [1.9.0] - 2023-01-17 ## [1.9.0] - 2023-01-17

View File

@ -247,7 +247,7 @@ class MLFlowLogger(Logger):
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
) )
continue continue
params_list.append(Param(key=v, value=v)) params_list.append(Param(key=k, value=v))
self.experiment.log_batch(run_id=self.run_id, params=params_list) self.experiment.log_batch(run_id=self.run_id, params=params_list)

View File

@ -253,8 +253,9 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
logger.log_hyperparams(params) logger.log_hyperparams(params)
logger.experiment.log_batch.assert_called_once_with( logger.experiment.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test_param", value="test_param")] run_id=logger.run_id, params=[param(key="test", value="test_param")]
) )
param.assert_called_with(key="test", value="test_param")
metrics = {"some_metric": 10} metrics = {"some_metric": 10}
logger.log_metrics(metrics) logger.log_metrics(metrics)
@ -262,6 +263,7 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
logger.experiment.log_batch.assert_called_with( logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)] run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
) )
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)
logger._mlflow_client.create_experiment.assert_called_once_with( logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location" name="test", artifact_location="my_artifact_location"