From 99281257689154e1b3896c4d3fcd33a0616eae25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Oct 2020 03:55:59 +0200 Subject: [PATCH] Mocking Loggers (part 4c, mlflow) (#3889) * base * add xfail * new test * import * missing import Co-authored-by: William Falcon --- tests/loggers/test_mlflow.py | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 17ca67d391..ac64717d20 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -1,7 +1,10 @@ +import importlib.util import os from unittest import mock from unittest.mock import MagicMock +import pytest + from mlflow.tracking import MlflowClient from pytorch_lightning import Trainer @@ -62,8 +65,42 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir): assert logger3.run_id == "run-id-3" +@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_log_dir(client, mlflow, tmpdir): + """ Test that the trainer saves checkpoints in the logger's save dir. """ + + # simulate experiment creation with mlflow client mock + run = MagicMock() + run.info.run_id = "run-id" + client.return_value.get_experiment_by_name = MagicMock(return_value=None) + client.return_value.create_experiment = MagicMock(return_value="exp-id") + client.return_value.create_run = MagicMock(return_value=run) + + # test construction of default log dir path + logger = MLFlowLogger("test", save_dir=tmpdir) + assert logger.save_dir == tmpdir + assert logger.version == "run-id" + assert logger.name == "exp-id" + + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + logger=logger, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=3, + ) + trainer.fit(model) + assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + + def test_mlflow_logger_dirs_creation(tmpdir): """ Test that the logger creates the folders and files in the right place. """ + if not importlib.util.find_spec('mlflow'): + pytest.xfail(f"test for explicit file creation requires mlflow dependency to be installed.") + assert not os.listdir(tmpdir) logger = MLFlowLogger('test', save_dir=tmpdir) assert logger.save_dir == tmpdir