Lazily import dependencies for MLFlowLogger (#18528)

Adrian Wälchli 2023-09-12 06:17:57 -07:00 committed by GitHub
parent c959df74b8
commit 6ab443f9c8
5 changed files with 158 additions and 139 deletions

View File

@@ -16,17 +16,14 @@ import os
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from lightning.pytorch.loggers.neptune import NeptuneLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger
__all__ = ["CSVLogger", "Logger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
__all__ = ["CSVLogger", "Logger", "MLFlowLogger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
if _COMET_AVAILABLE:
__all__.append("CometLogger")
# needed to prevent ModuleNotFoundError and duplicated logs.
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
if _MLFLOW_AVAILABLE:
__all__.append("MLFlowLogger")
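
With this change, `MLFlowLogger` is always importable from `lightning.pytorch.loggers`; the mlflow dependency is only needed once the logger is actually used. A minimal usage sketch (the guard below is illustrative, not part of this commit):

```python
# Sketch, assuming only what the diff shows: the import itself no longer requires mlflow.
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE  # RequirementCache("mlflow>=1.0.0", "mlflow")

if _MLFLOW_AVAILABLE:
    # mlflow is imported lazily inside the logger's methods when the logger is used.
    logger = MLFlowLogger(experiment_name="demo", save_dir="./mlruns")
else:
    logger = None  # illustrative fallback when mlflow is not installed
```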

View File

@@ -22,7 +22,7 @@ import tempfile
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Union
import yaml
from lightning_utilities.core.imports import RequirementCache
@@ -37,39 +37,6 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
if _MLFLOW_AVAILABLE:
import mlflow
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
else:
mlflow = None
MlflowClient, context = None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"
# before v1.1.0
if hasattr(context, "resolve_tags"):
from mlflow.tracking.context import resolve_tags
# since v1.1.0
elif hasattr(context, "registry"):
from mlflow.tracking.context.registry import resolve_tags
else:
def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
"""
Args:
tags: A dictionary of tags to override. If specified, tags passed in this argument will
override those inferred from the context.
Returns: A dictionary of resolved tags.
Note:
See ``mlflow.tracking.context.registry`` for more details.
"""
return tags
class MLFlowLogger(Logger):
@@ -169,11 +136,13 @@ class MLFlowLogger(Logger):
self._initialized = False
from mlflow.tracking import MlflowClient
self._mlflow_client = MlflowClient(tracking_uri)
@property
@rank_zero_experiment
def experiment(self) -> MlflowClient:
def experiment(self) -> Any:
r"""
Actual MLflow object. To use MLflow features in your
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
@@ -183,6 +152,8 @@ class MLFlowLogger(Logger):
self.logger.experiment.some_mlflow_function()
"""
import mlflow
if self._initialized:
return self._mlflow_client
@@ -207,11 +178,16 @@ class MLFlowLogger(Logger):
if self._run_id is None:
if self._run_name is not None:
self.tags = self.tags or {}
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
if MLFLOW_RUN_NAME in self.tags:
log.warning(
f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
)
self.tags[MLFLOW_RUN_NAME] = self._run_name
resolve_tags = _get_resolve_tags()
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
self._initialized = True
@@ -244,6 +220,8 @@ class MLFlowLogger(Logger):
params = _convert_params(params)
params = _flatten_dict(params)
from mlflow.entities import Param
# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
@@ -256,6 +234,8 @@ class MLFlowLogger(Logger):
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
from mlflow.entities import Metric
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics_list: List[Metric] = []
@@ -383,3 +363,18 @@ class MLFlowLogger(Logger):
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
self._logged_model_time[p] = t
def _get_resolve_tags() -> Callable:
from mlflow.tracking import context
# before v1.1.0
if hasattr(context, "resolve_tags"):
from mlflow.tracking.context import resolve_tags
# since v1.1.0
elif hasattr(context, "registry"):
from mlflow.tracking.context.registry import resolve_tags
else:
resolve_tags = lambda tags: tags
return resolve_tags
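
The removed top-level `import mlflow` block is replaced by the pattern visible above: a module-level `RequirementCache` plus deferred imports inside the methods that need them, with `_get_resolve_tags()` wrapping the version-dependent `resolve_tags` lookup. A standalone sketch of the same pattern, with illustrative names:

```python
# Standalone sketch of the lazy-import pattern (illustrative names, not the Lightning code).
from typing import Any, Dict, List

from lightning_utilities.core.imports import RequirementCache

# Does not import mlflow at module-import time; availability is checked when evaluated.
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")


class _LazyMlflowParams:
    """Importing this module never imports mlflow; only calling the method does."""

    def to_params(self, params: Dict[str, Any]) -> List[Any]:
        # Deferred import: the hard dependency is paid at call time, not at import time.
        from mlflow.entities import Param

        # Mirror the 250-character truncation applied in log_hyperparams above.
        return [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
```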

View File

@@ -0,0 +1,27 @@
import sys
from types import ModuleType
from unittest.mock import Mock
import pytest
@pytest.fixture()
def mlflow_mock(monkeypatch):
mlflow = ModuleType("mlflow")
mlflow.set_tracking_uri = Mock()
monkeypatch.setitem(sys.modules, "mlflow", mlflow)
mlflow_tracking = ModuleType("tracking")
mlflow_tracking.MlflowClient = Mock()
mlflow_tracking.artifact_utils = Mock()
monkeypatch.setitem(sys.modules, "mlflow.tracking", mlflow_tracking)
mlflow_entities = ModuleType("entities")
mlflow_entities.Metric = Mock()
mlflow_entities.Param = Mock()
mlflow_entities.time = Mock()
monkeypatch.setitem(sys.modules, "mlflow.entities", mlflow_entities)
mlflow.tracking = mlflow_tracking
mlflow.entities = mlflow_entities
return mlflow
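
A hypothetical test sketch (not part of this commit) showing how the fixture above is consumed, mirroring the decorator pattern of the updated tests below; the fake modules registered in `sys.modules` are what the logger's deferred imports resolve to:

```python
# Hypothetical test using the mlflow_mock fixture above (names are illustrative).
from unittest import mock
from unittest.mock import Mock

from lightning.pytorch.loggers.mlflow import MLFlowLogger


@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
def test_experiment_created_against_mock(_, mlflow_mock, tmp_path):
    # The logger's lazy `from mlflow.tracking import MlflowClient` resolves to this Mock.
    client = mlflow_mock.tracking.MlflowClient
    client.return_value.get_experiment_by_name.return_value = None
    client.return_value.create_experiment.return_value = "exp-id"

    logger = MLFlowLogger("demo", save_dir=str(tmp_path))
    _ = logger.experiment  # first access imports the (mocked) mlflow and creates the experiment

    client.return_value.create_experiment.assert_called_once()
```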

View File

@@ -42,8 +42,7 @@ LOGGER_CTX_MANAGERS = (
mock.patch("lightning.pytorch.loggers.comet.comet_ml"),
mock.patch("lightning.pytorch.loggers.comet.CometOfflineExperiment"),
mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"),
mock.patch("lightning.pytorch.loggers.mlflow.Metric"),
mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()),
mock.patch("lightning.pytorch.loggers.neptune.neptune", new_callable=create_neptune_mock),
mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
@@ -82,10 +81,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
return logger_class(**args)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(tmpdir, monkeypatch, logger_class):
def test_loggers_fit_test_all(logger_class, mlflow_mock, tmpdir):
"""Verify that basic functionality of all loggers."""
with contextlib.ExitStack() as stack:
for mgr in LOGGER_CTX_MANAGERS:
@@ -300,7 +298,7 @@ def _test_logger_initialization(tmpdir, logger_class):
trainer.fit(model)
def test_logger_with_prefix_all(tmpdir, monkeypatch):
def test_logger_with_prefix_all(mlflow_mock, monkeypatch, tmpdir):
"""Test that prefix is added at the beginning of the metric keys."""
prefix = "tmp"
@@ -315,10 +313,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.Metric"
) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"), mock.patch(
"lightning.pytorch.loggers.mlflow.mlflow"
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
):
Metric = mlflow_mock.entities.Metric
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_batch.assert_called_once_with(
@@ -358,7 +355,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
def test_logger_default_name(tmpdir, monkeypatch):
def test_logger_default_name(mlflow_mock, monkeypatch, tmpdir):
"""Test that the default logger name is lightning_logs."""
# CSV
logger = CSVLogger(save_dir=tmpdir)
@@ -376,9 +373,10 @@ def test_logger_default_name(tmpdir, monkeypatch):
# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.MlflowClient"
) as mlflow_client, mock.patch("lightning.pytorch.loggers.mlflow.mlflow"):
mlflow_client().get_experiment_by_name.return_value = None
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
):
client = mlflow_mock.tracking.MlflowClient()
client.get_experiment_by_name.return_value = None
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir)
_ = logger.experiment
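
The patch targets change here because `mlflow`, `MlflowClient`, and `Metric` are no longer module attributes of `lightning.pytorch.loggers.mlflow`; the deferred imports resolve through `sys.modules` instead, which is what the `mlflow_mock` fixture relies on. A self-contained illustration of that mechanism (hypothetical, not from the test suite):

```python
# Minimal demonstration: a deferred `from mlflow.tracking import MlflowClient` picks up
# whatever module object is registered under that key in sys.modules at call time.
import sys
from types import ModuleType
from unittest.mock import Mock

fake_mlflow = ModuleType("mlflow")
fake_tracking = ModuleType("tracking")
fake_tracking.MlflowClient = Mock()
fake_mlflow.tracking = fake_tracking
sys.modules["mlflow"] = fake_mlflow          # the fixture does this via monkeypatch.setitem
sys.modules["mlflow.tracking"] = fake_tracking

from mlflow.tracking import MlflowClient  # resolves to the Mock registered above

assert MlflowClient is fake_tracking.MlflowClient
```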

View File

@@ -19,8 +19,7 @@ import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
from lightning.pytorch.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
from lightning.pytorch.loggers.mlflow import _get_resolve_tags, _MLFLOW_AVAILABLE, MLFlowLogger
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -33,11 +32,12 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
return logger
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_exists(client, _, tmpdir):
def test_mlflow_logger_exists(_, mlflow_mock, tmp_path):
"""Test launching three independent loggers with either same or different experiment name."""
client = mlflow_mock.tracking.MlflowClient
run1 = MagicMock()
run1.info.run_id = "run-id-1"
run1.info.experiment_id = "exp-id-1"
@@ -53,7 +53,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
client.return_value.create_experiment = MagicMock(return_value="exp-id-1") # experiment_id
client.return_value.create_run = MagicMock(return_value=run1)
logger = MLFlowLogger("test", save_dir=tmpdir)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
assert logger._experiment_id is None
assert logger._run_id is None
_ = logger.experiment
@@ -69,7 +69,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
client.return_value.create_run = MagicMock(return_value=run2)
# same name leads to same experiment id, but different runs get recorded
logger2 = MLFlowLogger("test", save_dir=tmpdir)
logger2 = MLFlowLogger("test", save_dir=str(tmp_path))
assert logger2.experiment_id == logger.experiment_id
assert logger2.run_id == "run-id-2"
assert logger2.experiment.create_experiment.call_count == 0
@@ -82,43 +82,49 @@ def test_mlflow_logger_exists(client, _, tmpdir):
client.return_value.create_run = MagicMock(return_value=run3)
# logger with new experiment name causes new experiment id and new run id to be created
logger3 = MLFlowLogger("new", save_dir=tmpdir)
logger3 = MLFlowLogger("new", save_dir=str(tmp_path))
assert logger3.experiment_id == "exp-id-3" != logger.experiment_id
assert logger3.run_id == "run-id-3"
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_run_name_setting(client, _, tmpdir):
def test_mlflow_run_name_setting(tmp_path):
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
if not _MLFLOW_AVAILABLE:
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
resolve_tags = _get_resolve_tags()
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
# run_name is appended to tags
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=tmpdir)
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path))
logger._mlflow_client = client = Mock()
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
_ = logger.experiment
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
# run_name overrides tags[MLFLOW_RUN_NAME]
logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=tmpdir)
logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=str(tmp_path))
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
_ = logger.experiment
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
# default run_name (= None) does not append new tag
logger = MLFlowLogger("test", save_dir=tmpdir)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
_ = logger.experiment
default_tags = resolve_tags(None)
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_run_id_setting(client, _, tmpdir):
def test_mlflow_run_id_setting(_, mlflow_mock, tmp_path):
"""Test that the run_id argument uses the provided run_id."""
client = mlflow_mock.tracking.MlflowClient
run = MagicMock()
run.info.run_id = "run-id"
run.info.experiment_id = "experiment-id"
@@ -127,7 +133,7 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
client.return_value.get_run = MagicMock(return_value=run)
# run_id exists uses the existing run
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir)
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=str(tmp_path))
_ = logger.experiment
client.return_value.get_run.assert_called_with(run.info.run_id)
assert logger.experiment_id == run.info.experiment_id
@@ -135,11 +141,12 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
client.reset_mock(return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, _, tmpdir):
def test_mlflow_log_dir(_, mlflow_mock, tmp_path):
"""Test that the trainer saves checkpoints in the logger's save dir."""
client = mlflow_mock.tracking.MlflowClient
# simulate experiment creation with mlflow client mock
run = MagicMock()
run.info.run_id = "run-id"
@@ -148,37 +155,39 @@ def test_mlflow_log_dir(client, _, tmpdir):
client.return_value.create_run = MagicMock(return_value=run)
# test construction of default log dir path
logger = MLFlowLogger("test", save_dir=tmpdir)
assert logger.save_dir == tmpdir
logger = MLFlowLogger("test", save_dir=str(tmp_path))
assert logger.save_dir == str(tmp_path)
assert logger.version == "run-id"
assert logger.name == "exp-id"
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3)
trainer = Trainer(
default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3
)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / "checkpoints")
assert trainer.checkpoint_callback.dirpath == str(tmp_path / "exp-id" / "run-id" / "checkpoints")
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=1.ckpt"}
assert trainer.log_dir == logger.save_dir
def test_mlflow_logger_dirs_creation(tmpdir):
def test_mlflow_logger_dirs_creation(tmp_path):
"""Test that the logger creates the folders and files in the right place."""
if not _MLFLOW_AVAILABLE:
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
assert not os.listdir(tmpdir)
logger = MLFlowLogger("test", save_dir=tmpdir)
assert logger.save_dir == tmpdir
assert set(os.listdir(tmpdir)) == {".trash"}
assert not os.listdir(tmp_path)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
assert logger.save_dir == str(tmp_path)
assert set(os.listdir(tmp_path)) == {".trash"}
run_id = logger.run_id
exp_id = logger.experiment_id
# multiple experiment calls should not lead to new experiment folders
for i in range(2):
_ = logger.experiment
assert set(os.listdir(tmpdir)) == {".trash", exp_id}
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
assert set(os.listdir(tmp_path)) == {".trash", exp_id}
assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
class CustomModel(BoringModel):
def on_train_epoch_end(self, *args, **kwargs):
@@ -187,56 +196,53 @@ def test_mlflow_logger_dirs_creation(tmpdir):
model = CustomModel()
limit_batches = 5
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=tmp_path,
logger=logger,
max_epochs=1,
limit_train_batches=limit_batches,
limit_val_batches=limit_batches,
)
trainer.fit(model)
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
assert "epoch" in os.listdir(tmpdir / exp_id / run_id / "metrics")
assert set(os.listdir(tmpdir / exp_id / run_id / "params")) == model.hparams.keys()
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / "checkpoints")
assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
assert "epoch" in os.listdir(tmp_path / exp_id / run_id / "metrics")
assert set(os.listdir(tmp_path / exp_id / run_id / "params")) == model.hparams.keys()
assert trainer.checkpoint_callback.dirpath == str(tmp_path / exp_id / run_id / "checkpoints")
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
"""Test that the logger experiment_id retrieved only once."""
logger = MLFlowLogger("test", save_dir=tmpdir)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
_ = logger.experiment
_ = logger.experiment
_ = logger.experiment
assert logger.experiment.get_experiment_by_name.call_count == 1
@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
def test_mlflow_logger_with_unexpected_characters(_, mlflow_mock, tmp_path):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
metrics = {"[some_metric]": 10}
with pytest.warns(RuntimeWarning, match="special characters in metric name"):
logger.log_metrics(metrics)
@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
@mock.patch("lightning.pytorch.loggers.mlflow.Param")
@mock.patch("lightning.pytorch.loggers.mlflow.time")
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
def test_mlflow_logger_experiment_calls(_, mlflow_mock, tmp_path):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time = mlflow_mock.entities.time
metric = mlflow_mock.entities.Metric
param = mlflow_mock.entities.Param
time.return_value = 1
logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location")
logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
logger._mlflow_client.get_experiment_by_name.return_value = None
params = {"test": "test_param"}
@@ -260,17 +266,17 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
)
def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250
@mock.patch("lightning.pytorch.loggers.mlflow.Param", side_effect=_check_value_length)
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
def test_mlflow_logger_with_long_param_value(_, mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
logger = MLFlowLogger("test", save_dir=tmpdir)
def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250
mlflow_mock.entities.Param.side_effect = _check_value_length
logger = MLFlowLogger("test", save_dir=str(tmp_path))
params = {"test": "test_param" * 50}
logger.log_hyperparams(params)
@@ -279,13 +285,11 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
logger.experiment.log_batch.assert_called_once()
@mock.patch("lightning.pytorch.loggers.mlflow.Param")
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
logger = MLFlowLogger("test", save_dir=tmpdir)
def test_mlflow_logger_with_many_params(_, mlflow_mock, tmp_path):
"""Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
logger.log_hyperparams(params)
@@ -301,10 +305,9 @@ def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
("finished", "FINISHED"),
],
)
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize(_, __, status, expected):
def test_mlflow_logger_finalize(_, mlflow_mock, status, expected):
logger = MLFlowLogger("test")
# Pretend we are in a worker process and finalizing
@@ -315,10 +318,9 @@ def test_mlflow_logger_finalize(_, __, status, expected):
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize_when_exception(*_):
def test_mlflow_logger_finalize_when_exception(_, mlflow_mock):
logger = MLFlowLogger("test")
# Pretend we are on the main process and failing
@@ -334,19 +336,20 @@ def test_mlflow_logger_finalize_when_exception(*_):
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
@pytest.mark.parametrize("log_model", ["all", True, False])
def test_mlflow_log_model(client, _, tmpdir, log_model):
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
def test_mlflow_log_model(_, mlflow_mock, log_model, tmp_path):
"""Test that the logger creates the folders and files in the right place."""
client = mlflow_mock.tracking.MlflowClient
# Get model, logger, trainer and train
model = BoringModel()
logger = MLFlowLogger("test", save_dir=tmpdir, log_model=log_model)
logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=log_model)
logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=tmp_path,
logger=logger,
max_epochs=2,
limit_train_batches=3,
@@ -373,10 +376,9 @@ def test_mlflow_log_model(client, _, tmpdir, log_model):
assert not client.return_value.log_artifacts.called
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow")
def test_set_tracking_uri(mlflow_mock, *_):
def test_set_tracking_uri(_, mlflow_mock):
"""Test that the tracking uri is set for logging artifacts to MLFlow server."""
logger = MLFlowLogger(tracking_uri="the_tracking_uri")
mlflow_mock.set_tracking_uri.assert_not_called()