diff --git a/src/lightning/pytorch/loggers/__init__.py b/src/lightning/pytorch/loggers/__init__.py
index 0c3c454b73..44359e2ed8 100644
--- a/src/lightning/pytorch/loggers/__init__.py
+++ b/src/lightning/pytorch/loggers/__init__.py
@@ -16,17 +16,14 @@ import os
 from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger  # noqa: F401
 from lightning.pytorch.loggers.csv_logs import CSVLogger
 from lightning.pytorch.loggers.logger import Logger
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger  # noqa: F401
+from lightning.pytorch.loggers.mlflow import MLFlowLogger
 from lightning.pytorch.loggers.neptune import NeptuneLogger
 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
 from lightning.pytorch.loggers.wandb import WandbLogger
 
-__all__ = ["CSVLogger", "Logger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
+__all__ = ["CSVLogger", "Logger", "MLFlowLogger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
 
 if _COMET_AVAILABLE:
     __all__.append("CometLogger")
     # needed to prevent ModuleNotFoundError and duplicated logs.
     os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
-
-if _MLFLOW_AVAILABLE:
-    __all__.append("MLFlowLogger")
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 49387ecbea..563cac7f40 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -22,7 +22,7 @@ import tempfile
 from argparse import Namespace
 from pathlib import Path
 from time import time
-from typing import Any, Dict, List, Literal, Mapping, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Union
 
 import yaml
 from lightning_utilities.core.imports import RequirementCache
@@ -37,39 +37,6 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
 log = logging.getLogger(__name__)
 LOCAL_FILE_URI_PREFIX = "file:"
 _MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
-if _MLFLOW_AVAILABLE:
-    import mlflow
-    from mlflow.entities import Metric, Param
-    from mlflow.tracking import context, MlflowClient
-    from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
-else:
-    mlflow = None
-    MlflowClient, context = None, None
-    Metric, Param = None, None
-    MLFLOW_RUN_NAME = "mlflow.runName"
-
-# before v1.1.0
-if hasattr(context, "resolve_tags"):
-    from mlflow.tracking.context import resolve_tags
-
-
-# since v1.1.0
-elif hasattr(context, "registry"):
-    from mlflow.tracking.context.registry import resolve_tags
-else:
-
-    def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
-        """
-        Args:
-            tags: A dictionary of tags to override. If specified, tags passed in this argument will
-                override those inferred from the context.
-
-        Returns: A dictionary of resolved tags.
-
-        Note:
-            See ``mlflow.tracking.context.registry`` for more details.
-        """
-        return tags
 
 
 class MLFlowLogger(Logger):
@@ -169,11 +136,13 @@ class MLFlowLogger(Logger):
         self._initialized = False
 
+        from mlflow.tracking import MlflowClient
+
         self._mlflow_client = MlflowClient(tracking_uri)
 
     @property
     @rank_zero_experiment
-    def experiment(self) -> MlflowClient:
+    def experiment(self) -> Any:
         r"""
         Actual MLflow object. To use MLflow features in your
         :class:`~lightning.pytorch.core.module.LightningModule` do the following.
@@ -183,6 +152,8 @@ class MLFlowLogger(Logger):
             self.logger.experiment.some_mlflow_function()
 
         """
+        import mlflow
+
         if self._initialized:
             return self._mlflow_client
 
@@ -207,11 +178,16 @@ class MLFlowLogger(Logger):
         if self._run_id is None:
             if self._run_name is not None:
                 self.tags = self.tags or {}
+
+                from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
+
                 if MLFLOW_RUN_NAME in self.tags:
                     log.warning(
                         f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
                     )
                 self.tags[MLFLOW_RUN_NAME] = self._run_name
+
+            resolve_tags = _get_resolve_tags()
             run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
             self._run_id = run.info.run_id
         self._initialized = True
@@ -244,6 +220,8 @@ class MLFlowLogger(Logger):
         params = _convert_params(params)
         params = _flatten_dict(params)
 
+        from mlflow.entities import Param
+
         # Truncate parameter values to 250 characters.
         # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
         params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
@@ -256,6 +234,8 @@ class MLFlowLogger(Logger):
     def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
 
+        from mlflow.entities import Metric
+
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
         metrics_list: List[Metric] = []
 
@@ -383,3 +363,18 @@ class MLFlowLogger(Logger):
 
             # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
             self._logged_model_time[p] = t
+
+
+def _get_resolve_tags() -> Callable:
+    from mlflow.tracking import context
+
+    # before v1.1.0
+    if hasattr(context, "resolve_tags"):
+        from mlflow.tracking.context import resolve_tags
+    # since v1.1.0
+    elif hasattr(context, "registry"):
+        from mlflow.tracking.context.registry import resolve_tags
+    else:
+        resolve_tags = lambda tags: tags
+
+    return resolve_tags
diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py
new file mode 100644
index 0000000000..83eff760ea
--- /dev/null
+++ b/tests/tests_pytorch/loggers/conftest.py
@@ -0,0 +1,27 @@
+import sys
+from types import ModuleType
+from unittest.mock import Mock
+
+import pytest
+
+
+@pytest.fixture()
+def mlflow_mock(monkeypatch):
+    mlflow = ModuleType("mlflow")
+    mlflow.set_tracking_uri = Mock()
+    monkeypatch.setitem(sys.modules, "mlflow", mlflow)
+
+    mlflow_tracking = ModuleType("tracking")
+    mlflow_tracking.MlflowClient = Mock()
+    mlflow_tracking.artifact_utils = Mock()
+    monkeypatch.setitem(sys.modules, "mlflow.tracking", mlflow_tracking)
+
+    mlflow_entities = ModuleType("entities")
+    mlflow_entities.Metric = Mock()
+    mlflow_entities.Param = Mock()
+    mlflow_entities.time = Mock()
+    monkeypatch.setitem(sys.modules, "mlflow.entities", mlflow_entities)
+
+    mlflow.tracking = mlflow_tracking
+    mlflow.entities = mlflow_entities
+    return mlflow
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 7ef5bdba2c..118143f1da 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -42,8 +42,7 @@ LOGGER_CTX_MANAGERS = (
     mock.patch("lightning.pytorch.loggers.comet.comet_ml"),
     mock.patch("lightning.pytorch.loggers.comet.CometOfflineExperiment"),
     mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
-    mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"),
-    mock.patch("lightning.pytorch.loggers.mlflow.Metric"),
+    mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()),
     mock.patch("lightning.pytorch.loggers.neptune.neptune", new_callable=create_neptune_mock),
     mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
     mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
@@ -82,10 +81,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
     return logger_class(**args)
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
 @mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
 @pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
-def test_loggers_fit_test_all(tmpdir, monkeypatch, logger_class):
+def test_loggers_fit_test_all(logger_class, mlflow_mock, tmpdir):
     """Verify that basic functionality of all loggers."""
     with contextlib.ExitStack() as stack:
         for mgr in LOGGER_CTX_MANAGERS:
@@ -300,7 +298,7 @@ def _test_logger_initialization(tmpdir, logger_class):
     trainer.fit(model)
 
 
-def test_logger_with_prefix_all(tmpdir, monkeypatch):
+def test_logger_with_prefix_all(mlflow_mock, monkeypatch, tmpdir):
    """Test that prefix is added at the beginning of the metric keys."""
     prefix = "tmp"
 
@@ -315,10 +313,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
 
     # MLflow
     with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
-        "lightning.pytorch.loggers.mlflow.Metric"
-    ) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"), mock.patch(
-        "lightning.pytorch.loggers.mlflow.mlflow"
+        "lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
     ):
+        Metric = mlflow_mock.entities.Metric
         logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
         logger.log_metrics({"test": 1.0}, step=0)
         logger.experiment.log_batch.assert_called_once_with(
@@ -358,7 +355,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
 
 
-def test_logger_default_name(tmpdir, monkeypatch):
+def test_logger_default_name(mlflow_mock, monkeypatch, tmpdir):
     """Test that the default logger name is lightning_logs."""
     # CSV
     logger = CSVLogger(save_dir=tmpdir)
@@ -376,9 +373,10 @@ def test_logger_default_name(tmpdir, monkeypatch):
 
     # MLflow
     with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
-        "lightning.pytorch.loggers.mlflow.MlflowClient"
-    ) as mlflow_client, mock.patch("lightning.pytorch.loggers.mlflow.mlflow"):
-        mlflow_client().get_experiment_by_name.return_value = None
+        "lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
+    ):
+        client = mlflow_mock.tracking.MlflowClient()
+        client.get_experiment_by_name.return_value = None
         logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir)
         _ = logger.experiment
 
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 77b311b9b3..929e9f772f 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -19,8 +19,7 @@ import pytest
 
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
-from lightning.pytorch.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
+from lightning.pytorch.loggers.mlflow import _get_resolve_tags, _MLFLOW_AVAILABLE, MLFlowLogger
 
 
 def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -33,11 +32,12 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
     return logger
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_exists(client, _, tmpdir):
+def test_mlflow_logger_exists(_, mlflow_mock, tmp_path):
     """Test launching three independent loggers with either same or different experiment name."""
+    client = mlflow_mock.tracking.MlflowClient
+
     run1 = MagicMock()
     run1.info.run_id = "run-id-1"
     run1.info.experiment_id = "exp-id-1"
@@ -53,7 +53,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
     client.return_value.create_experiment = MagicMock(return_value="exp-id-1")  # experiment_id
     client.return_value.create_run = MagicMock(return_value=run1)
 
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     assert logger._experiment_id is None
     assert logger._run_id is None
     _ = logger.experiment
@@ -69,7 +69,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
     client.return_value.create_run = MagicMock(return_value=run2)
 
     # same name leads to same experiment id, but different runs get recorded
-    logger2 = MLFlowLogger("test", save_dir=tmpdir)
+    logger2 = MLFlowLogger("test", save_dir=str(tmp_path))
     assert logger2.experiment_id == logger.experiment_id
     assert logger2.run_id == "run-id-2"
     assert logger2.experiment.create_experiment.call_count == 0
@@ -82,43 +82,49 @@ def test_mlflow_logger_exists(client, _, tmpdir):
     client.return_value.create_run = MagicMock(return_value=run3)
 
     # logger with new experiment name causes new experiment id and new run id to be created
-    logger3 = MLFlowLogger("new", save_dir=tmpdir)
+    logger3 = MLFlowLogger("new", save_dir=str(tmp_path))
     assert logger3.experiment_id == "exp-id-3" != logger.experiment_id
     assert logger3.run_id == "run-id-3"
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_run_name_setting(client, _, tmpdir):
+def test_mlflow_run_name_setting(tmp_path):
     """Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
+    if not _MLFLOW_AVAILABLE:
+        pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
+
+    from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
+
+    resolve_tags = _get_resolve_tags()
     tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
 
     # run_name is appended to tags
-    logger = MLFlowLogger("test", run_name="run-name-1", save_dir=tmpdir)
+    logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path))
+    logger._mlflow_client = client = Mock()
+
     logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
     _ = logger.experiment
-    client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
+    client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
 
     # run_name overrides tags[MLFLOW_RUN_NAME]
-    logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=tmpdir)
+    logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=str(tmp_path))
     logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
     _ = logger.experiment
-    client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
+    client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
 
     # default run_name (= None) does not append new tag
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
     _ = logger.experiment
     default_tags = resolve_tags(None)
-    client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
+    client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_run_id_setting(client, _, tmpdir):
+def test_mlflow_run_id_setting(_, mlflow_mock, tmp_path):
     """Test that the run_id argument uses the provided run_id."""
+    client = mlflow_mock.tracking.MlflowClient
+
     run = MagicMock()
     run.info.run_id = "run-id"
     run.info.experiment_id = "experiment-id"
@@ -127,7 +133,7 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
     client.return_value.get_run = MagicMock(return_value=run)
 
     # run_id exists uses the existing run
-    logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir)
+    logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=str(tmp_path))
     _ = logger.experiment
     client.return_value.get_run.assert_called_with(run.info.run_id)
     assert logger.experiment_id == run.info.experiment_id
@@ -135,11 +141,12 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
     client.reset_mock(return_value=True)
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_log_dir(client, _, tmpdir):
+def test_mlflow_log_dir(_, mlflow_mock, tmp_path):
     """Test that the trainer saves checkpoints in the logger's save dir."""
+    client = mlflow_mock.tracking.MlflowClient
+
     # simulate experiment creation with mlflow client mock
     run = MagicMock()
     run.info.run_id = "run-id"
@@ -148,37 +155,39 @@ def test_mlflow_log_dir(client, _, tmpdir):
     client.return_value.create_run = MagicMock(return_value=run)
 
     # test construction of default log dir path
-    logger = MLFlowLogger("test", save_dir=tmpdir)
-    assert logger.save_dir == tmpdir
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
+    assert logger.save_dir == str(tmp_path)
     assert logger.version == "run-id"
     assert logger.name == "exp-id"
 
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3)
+    trainer = Trainer(
+        default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3
+    )
     assert trainer.log_dir == logger.save_dir
     trainer.fit(model)
-    assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / "checkpoints")
+    assert trainer.checkpoint_callback.dirpath == str(tmp_path / "exp-id" / "run-id" / "checkpoints")
     assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=1.ckpt"}
     assert trainer.log_dir == logger.save_dir
 
 
-def test_mlflow_logger_dirs_creation(tmpdir):
+def test_mlflow_logger_dirs_creation(tmp_path):
     """Test that the logger creates the folders and files in the right place."""
     if not _MLFLOW_AVAILABLE:
         pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
 
-    assert not os.listdir(tmpdir)
-    logger = MLFlowLogger("test", save_dir=tmpdir)
-    assert logger.save_dir == tmpdir
-    assert set(os.listdir(tmpdir)) == {".trash"}
+    assert not os.listdir(tmp_path)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
+    assert logger.save_dir == str(tmp_path)
+    assert set(os.listdir(tmp_path)) == {".trash"}
     run_id = logger.run_id
     exp_id = logger.experiment_id
 
     # multiple experiment calls should not lead to new experiment folders
     for i in range(2):
         _ = logger.experiment
-        assert set(os.listdir(tmpdir)) == {".trash", exp_id}
-        assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
+        assert set(os.listdir(tmp_path)) == {".trash", exp_id}
+        assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
 
     class CustomModel(BoringModel):
         def on_train_epoch_end(self, *args, **kwargs):
@@ -187,56 +196,53 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     model = CustomModel()
     limit_batches = 5
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         logger=logger,
         max_epochs=1,
         limit_train_batches=limit_batches,
         limit_val_batches=limit_batches,
     )
     trainer.fit(model)
-    assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
-    assert "epoch" in os.listdir(tmpdir / exp_id / run_id / "metrics")
-    assert set(os.listdir(tmpdir / exp_id / run_id / "params")) == model.hparams.keys()
-    assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / "checkpoints")
+    assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
+    assert "epoch" in os.listdir(tmp_path / exp_id / run_id / "metrics")
+    assert set(os.listdir(tmp_path / exp_id / run_id / "params")) == model.hparams.keys()
+    assert trainer.checkpoint_callback.dirpath == str(tmp_path / exp_id / run_id / "checkpoints")
     assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
+def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
     """Test that the logger experiment_id retrieved only once."""
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     _ = logger.experiment
     _ = logger.experiment
     _ = logger.experiment
     assert logger.experiment.get_experiment_by_name.call_count == 1
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
+def test_mlflow_logger_with_unexpected_characters(_, mlflow_mock, tmp_path):
     """Test that the logger raises warning with special characters not accepted by MLFlow."""
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     metrics = {"[some_metric]": 10}
 
     with pytest.warns(RuntimeWarning, match="special characters in metric name"):
         logger.log_metrics(metrics)
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
-@mock.patch("lightning.pytorch.loggers.mlflow.Param")
-@mock.patch("lightning.pytorch.loggers.mlflow.time")
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
+def test_mlflow_logger_experiment_calls(_, mlflow_mock, tmp_path):
     """Test that the logger calls methods on the mlflow experiment correctly."""
+    time = mlflow_mock.entities.time
+    metric = mlflow_mock.entities.Metric
+    param = mlflow_mock.entities.Param
+
     time.return_value = 1
 
-    logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location")
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
     logger._mlflow_client.get_experiment_by_name.return_value = None
 
     params = {"test": "test_param"}
@@ -260,17 +266,17 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
     )
 
 
-def _check_value_length(value, *args, **kwargs):
-    assert len(value) <= 250
-
-
-@mock.patch("lightning.pytorch.loggers.mlflow.Param", side_effect=_check_value_length)
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
+def test_mlflow_logger_with_long_param_value(_, mlflow_mock, tmp_path):
     """Test that long parameter values are truncated to 250 characters."""
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+
+    def _check_value_length(value, *args, **kwargs):
+        assert len(value) <= 250
+
+    mlflow_mock.entities.Param.side_effect = _check_value_length
+
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     params = {"test": "test_param" * 50}
     logger.log_hyperparams(params)
 
@@ -279,13 +285,11 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
     logger.experiment.log_batch.assert_called_once()
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow.Param")
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
-    """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
-    logger = MLFlowLogger("test", save_dir=tmpdir)
+def test_mlflow_logger_with_many_params(_, mlflow_mock, tmp_path):
+    """Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
+    logger = MLFlowLogger("test", save_dir=str(tmp_path))
     params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
     logger.log_hyperparams(params)
 
@@ -301,10 +305,9 @@ def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
         ("finished", "FINISHED"),
     ],
 )
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_finalize(_, __, status, expected):
+def test_mlflow_logger_finalize(_, mlflow_mock, status, expected):
     logger = MLFlowLogger("test")
 
     # Pretend we are in a worker process and finalizing
@@ -315,10 +318,9 @@ def test_mlflow_logger_finalize(_, __, status, expected):
     logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-def test_mlflow_logger_finalize_when_exception(*_):
+def test_mlflow_logger_finalize_when_exception(_, mlflow_mock):
     logger = MLFlowLogger("test")
 
     # Pretend we are on the main process and failing
@@ -334,19 +336,20 @@ def test_mlflow_logger_finalize_when_exception(*_):
     logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
 
 
-@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
 @pytest.mark.parametrize("log_model", ["all", True, False])
-def test_mlflow_log_model(client, _, tmpdir, log_model):
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
+def test_mlflow_log_model(_, mlflow_mock, log_model, tmp_path):
     """Test that the logger creates the folders and files in the right place."""
+    client = mlflow_mock.tracking.MlflowClient
+
     # Get model, logger, trainer and train
     model = BoringModel()
-    logger = MLFlowLogger("test", save_dir=tmpdir, log_model=log_model)
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=log_model)
    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
 
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         logger=logger,
         max_epochs=2,
         limit_train_batches=3,
@@ -373,10 +376,9 @@ def test_mlflow_log_model(client, _, tmpdir, log_model):
         assert not client.return_value.log_artifacts.called
 
 
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
 @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
-@mock.patch("lightning.pytorch.loggers.mlflow.mlflow")
-def test_set_tracking_uri(mlflow_mock, *_):
+def test_set_tracking_uri(_, mlflow_mock):
     """Test that the tracking uri is set for logging artifacts to MLFlow server."""
     logger = MLFlowLogger(tracking_uri="the_tracking_uri")
     mlflow_mock.set_tracking_uri.assert_not_called()