Lazily import dependencies for MLFlowLogger (#18528)
This commit is contained in:
parent
c959df74b8
commit
6ab443f9c8
|
@ -16,17 +16,14 @@ import os
|
|||
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
|
||||
from lightning.pytorch.loggers.csv_logs import CSVLogger
|
||||
from lightning.pytorch.loggers.logger import Logger
|
||||
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
|
||||
from lightning.pytorch.loggers.mlflow import MLFlowLogger
|
||||
from lightning.pytorch.loggers.neptune import NeptuneLogger
|
||||
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
|
||||
from lightning.pytorch.loggers.wandb import WandbLogger
|
||||
|
||||
__all__ = ["CSVLogger", "Logger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
|
||||
__all__ = ["CSVLogger", "Logger", "MLFlowLogger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
|
||||
|
||||
if _COMET_AVAILABLE:
|
||||
__all__.append("CometLogger")
|
||||
# needed to prevent ModuleNotFoundError and duplicated logs.
|
||||
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
|
||||
|
||||
if _MLFLOW_AVAILABLE:
|
||||
__all__.append("MLFlowLogger")
|
||||
|
|
|
@ -22,7 +22,7 @@ import tempfile
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from time import time
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
import yaml
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
|
@ -37,39 +37,6 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
|||
log = logging.getLogger(__name__)
|
||||
LOCAL_FILE_URI_PREFIX = "file:"
|
||||
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
|
||||
if _MLFLOW_AVAILABLE:
|
||||
import mlflow
|
||||
from mlflow.entities import Metric, Param
|
||||
from mlflow.tracking import context, MlflowClient
|
||||
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
|
||||
else:
|
||||
mlflow = None
|
||||
MlflowClient, context = None, None
|
||||
Metric, Param = None, None
|
||||
MLFLOW_RUN_NAME = "mlflow.runName"
|
||||
|
||||
# before v1.1.0
|
||||
if hasattr(context, "resolve_tags"):
|
||||
from mlflow.tracking.context import resolve_tags
|
||||
|
||||
|
||||
# since v1.1.0
|
||||
elif hasattr(context, "registry"):
|
||||
from mlflow.tracking.context.registry import resolve_tags
|
||||
else:
|
||||
|
||||
def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
|
||||
"""
|
||||
Args:
|
||||
tags: A dictionary of tags to override. If specified, tags passed in this argument will
|
||||
override those inferred from the context.
|
||||
|
||||
Returns: A dictionary of resolved tags.
|
||||
|
||||
Note:
|
||||
See ``mlflow.tracking.context.registry`` for more details.
|
||||
"""
|
||||
return tags
|
||||
|
||||
|
||||
class MLFlowLogger(Logger):
|
||||
|
@ -169,11 +136,13 @@ class MLFlowLogger(Logger):
|
|||
|
||||
self._initialized = False
|
||||
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
self._mlflow_client = MlflowClient(tracking_uri)
|
||||
|
||||
@property
|
||||
@rank_zero_experiment
|
||||
def experiment(self) -> MlflowClient:
|
||||
def experiment(self) -> Any:
|
||||
r"""
|
||||
Actual MLflow object. To use MLflow features in your
|
||||
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
|
||||
|
@ -183,6 +152,8 @@ class MLFlowLogger(Logger):
|
|||
self.logger.experiment.some_mlflow_function()
|
||||
|
||||
"""
|
||||
import mlflow
|
||||
|
||||
if self._initialized:
|
||||
return self._mlflow_client
|
||||
|
||||
|
@ -207,11 +178,16 @@ class MLFlowLogger(Logger):
|
|||
if self._run_id is None:
|
||||
if self._run_name is not None:
|
||||
self.tags = self.tags or {}
|
||||
|
||||
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
|
||||
|
||||
if MLFLOW_RUN_NAME in self.tags:
|
||||
log.warning(
|
||||
f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
|
||||
)
|
||||
self.tags[MLFLOW_RUN_NAME] = self._run_name
|
||||
|
||||
resolve_tags = _get_resolve_tags()
|
||||
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
|
||||
self._run_id = run.info.run_id
|
||||
self._initialized = True
|
||||
|
@ -244,6 +220,8 @@ class MLFlowLogger(Logger):
|
|||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
|
||||
from mlflow.entities import Param
|
||||
|
||||
# Truncate parameter values to 250 characters.
|
||||
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
|
||||
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
|
||||
|
@ -256,6 +234,8 @@ class MLFlowLogger(Logger):
|
|||
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
||||
from mlflow.entities import Metric
|
||||
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
metrics_list: List[Metric] = []
|
||||
|
||||
|
@ -383,3 +363,18 @@ class MLFlowLogger(Logger):
|
|||
|
||||
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
|
||||
self._logged_model_time[p] = t
|
||||
|
||||
|
||||
def _get_resolve_tags() -> Callable:
    """Return the tag-resolving callable that matches the installed MLflow layout.

    MLflow is imported inside this function, so the import only happens when the
    function is called rather than when the defining module is imported.

    Returns:
        ``mlflow.tracking.context.resolve_tags`` if the ``context`` package exposes
        ``resolve_tags`` directly; otherwise ``mlflow.tracking.context.registry.resolve_tags``
        if ``context`` has a ``registry`` attribute; otherwise a pass-through function
        that returns the tags it was given.

    """
    from mlflow.tracking import context

    # before v1.1.0
    if hasattr(context, "resolve_tags"):
        from mlflow.tracking.context import resolve_tags
    # since v1.1.0
    elif hasattr(context, "registry"):
        from mlflow.tracking.context.registry import resolve_tags
    else:
        # A named function (not a lambda) so the fallback keeps the ``tags=None``
        # default and can be called without arguments.
        def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
            """Return ``tags`` unchanged.

            Args:
                tags: A dictionary of tags to pass through, or ``None``.

            Returns:
                The same object that was passed in.

            """
            return tags

    return resolve_tags
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
def mlflow_mock(monkeypatch):
    """Register stand-in ``mlflow`` modules in ``sys.modules`` and return the root stub.

    Three empty modules are created and registered through ``monkeypatch.setitem``
    under the keys ``"mlflow"``, ``"mlflow.tracking"`` and ``"mlflow.entities"``.
    Their attributes are ``Mock`` objects:

    - root: ``set_tracking_uri``
    - tracking: ``MlflowClient``, ``artifact_utils``
    - entities: ``Metric``, ``Param``, ``time``

    The submodule stubs are also attached to the root as ``tracking`` and ``entities``.

    """
    root = ModuleType("mlflow")
    tracking = ModuleType("tracking")
    entities = ModuleType("entities")

    # Populate each stub with the mocked attributes.
    root.set_tracking_uri = Mock()
    for attr_name in ("MlflowClient", "artifact_utils"):
        setattr(tracking, attr_name, Mock())
    for attr_name in ("Metric", "Param", "time"):
        setattr(entities, attr_name, Mock())

    # Make the submodules reachable by attribute access on the root stub.
    root.tracking = tracking
    root.entities = entities

    # Register the stubs under their dotted import names.
    registrations = (
        ("mlflow", root),
        ("mlflow.tracking", tracking),
        ("mlflow.entities", entities),
    )
    for dotted_name, module in registrations:
        monkeypatch.setitem(sys.modules, dotted_name, module)

    return root
|
|
@ -42,8 +42,7 @@ LOGGER_CTX_MANAGERS = (
|
|||
mock.patch("lightning.pytorch.loggers.comet.comet_ml"),
|
||||
mock.patch("lightning.pytorch.loggers.comet.CometOfflineExperiment"),
|
||||
mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
|
||||
mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"),
|
||||
mock.patch("lightning.pytorch.loggers.mlflow.Metric"),
|
||||
mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()),
|
||||
mock.patch("lightning.pytorch.loggers.neptune.neptune", new_callable=create_neptune_mock),
|
||||
mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
|
||||
mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
|
||||
|
@ -82,10 +81,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
|
|||
return logger_class(**args)
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
|
||||
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
|
||||
def test_loggers_fit_test_all(tmpdir, monkeypatch, logger_class):
|
||||
def test_loggers_fit_test_all(logger_class, mlflow_mock, tmpdir):
|
||||
"""Verify that basic functionality of all loggers."""
|
||||
with contextlib.ExitStack() as stack:
|
||||
for mgr in LOGGER_CTX_MANAGERS:
|
||||
|
@ -300,7 +298,7 @@ def _test_logger_initialization(tmpdir, logger_class):
|
|||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
||||
def test_logger_with_prefix_all(mlflow_mock, monkeypatch, tmpdir):
|
||||
"""Test that prefix is added at the beginning of the metric keys."""
|
||||
prefix = "tmp"
|
||||
|
||||
|
@ -315,10 +313,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
|||
|
||||
# MLflow
|
||||
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
|
||||
"lightning.pytorch.loggers.mlflow.Metric"
|
||||
) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"), mock.patch(
|
||||
"lightning.pytorch.loggers.mlflow.mlflow"
|
||||
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
|
||||
):
|
||||
Metric = mlflow_mock.entities.Metric
|
||||
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
|
||||
logger.log_metrics({"test": 1.0}, step=0)
|
||||
logger.experiment.log_batch.assert_called_once_with(
|
||||
|
@ -358,7 +355,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
|||
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
|
||||
|
||||
|
||||
def test_logger_default_name(tmpdir, monkeypatch):
|
||||
def test_logger_default_name(mlflow_mock, monkeypatch, tmpdir):
|
||||
"""Test that the default logger name is lightning_logs."""
|
||||
# CSV
|
||||
logger = CSVLogger(save_dir=tmpdir)
|
||||
|
@ -376,9 +373,10 @@ def test_logger_default_name(tmpdir, monkeypatch):
|
|||
|
||||
# MLflow
|
||||
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
|
||||
"lightning.pytorch.loggers.mlflow.MlflowClient"
|
||||
) as mlflow_client, mock.patch("lightning.pytorch.loggers.mlflow.mlflow"):
|
||||
mlflow_client().get_experiment_by_name.return_value = None
|
||||
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
|
||||
):
|
||||
client = mlflow_mock.tracking.MlflowClient()
|
||||
client.get_experiment_by_name.return_value = None
|
||||
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir)
|
||||
|
||||
_ = logger.experiment
|
||||
|
|
|
@ -19,8 +19,7 @@ import pytest
|
|||
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
|
||||
from lightning.pytorch.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
|
||||
from lightning.pytorch.loggers.mlflow import _get_resolve_tags, _MLFLOW_AVAILABLE, MLFlowLogger
|
||||
|
||||
|
||||
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
|
||||
|
@ -33,11 +32,12 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
|
|||
return logger
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_exists(client, _, tmpdir):
|
||||
def test_mlflow_logger_exists(_, mlflow_mock, tmp_path):
|
||||
"""Test launching three independent loggers with either same or different experiment name."""
|
||||
client = mlflow_mock.tracking.MlflowClient
|
||||
|
||||
run1 = MagicMock()
|
||||
run1.info.run_id = "run-id-1"
|
||||
run1.info.experiment_id = "exp-id-1"
|
||||
|
@ -53,7 +53,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
|
|||
client.return_value.create_experiment = MagicMock(return_value="exp-id-1") # experiment_id
|
||||
client.return_value.create_run = MagicMock(return_value=run1)
|
||||
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
assert logger._experiment_id is None
|
||||
assert logger._run_id is None
|
||||
_ = logger.experiment
|
||||
|
@ -69,7 +69,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):
|
|||
client.return_value.create_run = MagicMock(return_value=run2)
|
||||
|
||||
# same name leads to same experiment id, but different runs get recorded
|
||||
logger2 = MLFlowLogger("test", save_dir=tmpdir)
|
||||
logger2 = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
assert logger2.experiment_id == logger.experiment_id
|
||||
assert logger2.run_id == "run-id-2"
|
||||
assert logger2.experiment.create_experiment.call_count == 0
|
||||
|
@ -82,43 +82,49 @@ def test_mlflow_logger_exists(client, _, tmpdir):
|
|||
client.return_value.create_run = MagicMock(return_value=run3)
|
||||
|
||||
# logger with new experiment name causes new experiment id and new run id to be created
|
||||
logger3 = MLFlowLogger("new", save_dir=tmpdir)
|
||||
logger3 = MLFlowLogger("new", save_dir=str(tmp_path))
|
||||
assert logger3.experiment_id == "exp-id-3" != logger.experiment_id
|
||||
assert logger3.run_id == "run-id-3"
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_run_name_setting(client, _, tmpdir):
|
||||
def test_mlflow_run_name_setting(tmp_path):
|
||||
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
|
||||
if not _MLFLOW_AVAILABLE:
|
||||
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
|
||||
|
||||
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
|
||||
|
||||
resolve_tags = _get_resolve_tags()
|
||||
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
|
||||
|
||||
# run_name is appended to tags
|
||||
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path))
|
||||
logger._mlflow_client = client = Mock()
|
||||
|
||||
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
|
||||
_ = logger.experiment
|
||||
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
|
||||
client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
|
||||
|
||||
# run_name overrides tags[MLFLOW_RUN_NAME]
|
||||
logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", run_name="run-name-1", tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=str(tmp_path))
|
||||
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
|
||||
_ = logger.experiment
|
||||
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
|
||||
client.create_run.assert_called_with(experiment_id="exp-id", tags=tags)
|
||||
|
||||
# default run_name (= None) does not append new tag
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
|
||||
_ = logger.experiment
|
||||
default_tags = resolve_tags(None)
|
||||
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
|
||||
client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_run_id_setting(client, _, tmpdir):
|
||||
def test_mlflow_run_id_setting(_, mlflow_mock, tmp_path):
|
||||
"""Test that the run_id argument uses the provided run_id."""
|
||||
client = mlflow_mock.tracking.MlflowClient
|
||||
|
||||
run = MagicMock()
|
||||
run.info.run_id = "run-id"
|
||||
run.info.experiment_id = "experiment-id"
|
||||
|
@ -127,7 +133,7 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
|
|||
client.return_value.get_run = MagicMock(return_value=run)
|
||||
|
||||
# run_id exists uses the existing run
|
||||
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=str(tmp_path))
|
||||
_ = logger.experiment
|
||||
client.return_value.get_run.assert_called_with(run.info.run_id)
|
||||
assert logger.experiment_id == run.info.experiment_id
|
||||
|
@ -135,11 +141,12 @@ def test_mlflow_run_id_setting(client, _, tmpdir):
|
|||
client.reset_mock(return_value=True)
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_log_dir(client, _, tmpdir):
|
||||
def test_mlflow_log_dir(_, mlflow_mock, tmp_path):
|
||||
"""Test that the trainer saves checkpoints in the logger's save dir."""
|
||||
client = mlflow_mock.tracking.MlflowClient
|
||||
|
||||
# simulate experiment creation with mlflow client mock
|
||||
run = MagicMock()
|
||||
run.info.run_id = "run-id"
|
||||
|
@ -148,37 +155,39 @@ def test_mlflow_log_dir(client, _, tmpdir):
|
|||
client.return_value.create_run = MagicMock(return_value=run)
|
||||
|
||||
# test construction of default log dir path
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
assert logger.save_dir == tmpdir
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
assert logger.save_dir == str(tmp_path)
|
||||
assert logger.version == "run-id"
|
||||
assert logger.name == "exp-id"
|
||||
|
||||
model = BoringModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=1, limit_val_batches=3
|
||||
)
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
trainer.fit(model)
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / "checkpoints")
|
||||
assert trainer.checkpoint_callback.dirpath == str(tmp_path / "exp-id" / "run-id" / "checkpoints")
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=1.ckpt"}
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
|
||||
|
||||
def test_mlflow_logger_dirs_creation(tmpdir):
|
||||
def test_mlflow_logger_dirs_creation(tmp_path):
|
||||
"""Test that the logger creates the folders and files in the right place."""
|
||||
if not _MLFLOW_AVAILABLE:
|
||||
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
|
||||
|
||||
assert not os.listdir(tmpdir)
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
assert logger.save_dir == tmpdir
|
||||
assert set(os.listdir(tmpdir)) == {".trash"}
|
||||
assert not os.listdir(tmp_path)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
assert logger.save_dir == str(tmp_path)
|
||||
assert set(os.listdir(tmp_path)) == {".trash"}
|
||||
run_id = logger.run_id
|
||||
exp_id = logger.experiment_id
|
||||
|
||||
# multiple experiment calls should not lead to new experiment folders
|
||||
for i in range(2):
|
||||
_ = logger.experiment
|
||||
assert set(os.listdir(tmpdir)) == {".trash", exp_id}
|
||||
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
|
||||
assert set(os.listdir(tmp_path)) == {".trash", exp_id}
|
||||
assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
|
||||
|
||||
class CustomModel(BoringModel):
|
||||
def on_train_epoch_end(self, *args, **kwargs):
|
||||
|
@ -187,56 +196,53 @@ def test_mlflow_logger_dirs_creation(tmpdir):
|
|||
model = CustomModel()
|
||||
limit_batches = 5
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
default_root_dir=tmp_path,
|
||||
logger=logger,
|
||||
max_epochs=1,
|
||||
limit_train_batches=limit_batches,
|
||||
limit_val_batches=limit_batches,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
|
||||
assert "epoch" in os.listdir(tmpdir / exp_id / run_id / "metrics")
|
||||
assert set(os.listdir(tmpdir / exp_id / run_id / "params")) == model.hparams.keys()
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / "checkpoints")
|
||||
assert set(os.listdir(tmp_path / exp_id)) == {run_id, "meta.yaml"}
|
||||
assert "epoch" in os.listdir(tmp_path / exp_id / run_id / "metrics")
|
||||
assert set(os.listdir(tmp_path / exp_id / run_id / "params")) == model.hparams.keys()
|
||||
assert trainer.checkpoint_callback.dirpath == str(tmp_path / exp_id / run_id / "checkpoints")
|
||||
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
|
||||
def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
|
||||
"""Test that the logger experiment_id retrieved only once."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
_ = logger.experiment
|
||||
_ = logger.experiment
|
||||
_ = logger.experiment
|
||||
assert logger.experiment.get_experiment_by_name.call_count == 1
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
|
||||
def test_mlflow_logger_with_unexpected_characters(_, mlflow_mock, tmp_path):
|
||||
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
metrics = {"[some_metric]": 10}
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="special characters in metric name"):
|
||||
logger.log_metrics(metrics)
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.Param")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.time")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
|
||||
def test_mlflow_logger_experiment_calls(_, mlflow_mock, tmp_path):
|
||||
"""Test that the logger calls methods on the mlflow experiment correctly."""
|
||||
time = mlflow_mock.entities.time
|
||||
metric = mlflow_mock.entities.Metric
|
||||
param = mlflow_mock.entities.Param
|
||||
|
||||
time.return_value = 1
|
||||
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location")
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
|
||||
logger._mlflow_client.get_experiment_by_name.return_value = None
|
||||
|
||||
params = {"test": "test_param"}
|
||||
|
@ -260,17 +266,17 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
|
|||
)
|
||||
|
||||
|
||||
def _check_value_length(value, *args, **kwargs):
|
||||
assert len(value) <= 250
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.Param", side_effect=_check_value_length)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
|
||||
def test_mlflow_logger_with_long_param_value(_, mlflow_mock, tmp_path):
|
||||
"""Test that long parameter values are truncated to 250 characters."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
|
||||
def _check_value_length(value, *args, **kwargs):
|
||||
assert len(value) <= 250
|
||||
|
||||
mlflow_mock.entities.Param.side_effect = _check_value_length
|
||||
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
|
||||
params = {"test": "test_param" * 50}
|
||||
logger.log_hyperparams(params)
|
||||
|
@ -279,13 +285,11 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
|
|||
logger.experiment.log_batch.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.Param")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
|
||||
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
def test_mlflow_logger_with_many_params(_, mlflow_mock, tmp_path):
|
||||
"""Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path))
|
||||
|
||||
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
|
||||
logger.log_hyperparams(params)
|
||||
|
@ -301,10 +305,9 @@ def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
|
|||
("finished", "FINISHED"),
|
||||
],
|
||||
)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_finalize(_, __, status, expected):
|
||||
def test_mlflow_logger_finalize(_, mlflow_mock, status, expected):
|
||||
logger = MLFlowLogger("test")
|
||||
|
||||
# Pretend we are in a worker process and finalizing
|
||||
|
@ -315,10 +318,9 @@ def test_mlflow_logger_finalize(_, __, status, expected):
|
|||
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_logger_finalize_when_exception(*_):
|
||||
def test_mlflow_logger_finalize_when_exception(_, mlflow_mock):
|
||||
logger = MLFlowLogger("test")
|
||||
|
||||
# Pretend we are on the main process and failing
|
||||
|
@ -334,19 +336,20 @@ def test_mlflow_logger_finalize_when_exception(*_):
|
|||
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
@pytest.mark.parametrize("log_model", ["all", True, False])
|
||||
def test_mlflow_log_model(client, _, tmpdir, log_model):
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
def test_mlflow_log_model(_, mlflow_mock, log_model, tmp_path):
|
||||
"""Test that the logger creates the folders and files in the right place."""
|
||||
client = mlflow_mock.tracking.MlflowClient
|
||||
|
||||
# Get model, logger, trainer and train
|
||||
model = BoringModel()
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir, log_model=log_model)
|
||||
logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=log_model)
|
||||
logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
default_root_dir=tmp_path,
|
||||
logger=logger,
|
||||
max_epochs=2,
|
||||
limit_train_batches=3,
|
||||
|
@ -373,10 +376,9 @@ def test_mlflow_log_model(client, _, tmpdir, log_model):
|
|||
assert not client.return_value.log_artifacts.called
|
||||
|
||||
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
|
||||
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow")
|
||||
def test_set_tracking_uri(mlflow_mock, *_):
|
||||
def test_set_tracking_uri(_, mlflow_mock):
|
||||
"""Test that the tracking uri is set for logging artifacts to MLFlow server."""
|
||||
logger = MLFlowLogger(tracking_uri="the_tracking_uri")
|
||||
mlflow_mock.set_tracking_uri.assert_not_called()
|
||||
|
|
Loading…
Reference in New Issue