Mocking Loggers (part 4c, mlflow) (#3889)

* base

* add xfail

* new test

* import

* missing import

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Adrian Wälchli 2020-10-07 03:55:59 +02:00 committed by GitHub
parent d71ed277d4
commit 9928125768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 37 additions and 0 deletions

View File

@ -1,7 +1,10 @@
import importlib.util
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from mlflow.tracking import MlflowClient
from pytorch_lightning import Trainer
@ -62,8 +65,42 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
assert logger3.run_id == "run-id-3"
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
""" Test that the trainer saves checkpoints in the logger's save dir. """
# simulate experiment creation with mlflow client mock
run = MagicMock()
run.info.run_id = "run-id"
client.return_value.get_experiment_by_name = MagicMock(return_value=None)
client.return_value.create_experiment = MagicMock(return_value="exp-id")
client.return_value.create_run = MagicMock(return_value=run)
# test construction of default log dir path
logger = MLFlowLogger("test", save_dir=tmpdir)
assert logger.save_dir == tmpdir
assert logger.version == "run-id"
assert logger.name == "exp-id"
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
logger=logger,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=3,
)
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
def test_mlflow_logger_dirs_creation(tmpdir):
""" Test that the logger creates the folders and files in the right place. """
if not importlib.util.find_spec('mlflow'):
pytest.xfail(f"test for explicit file creation requires mlflow dependency to be installed.")
assert not os.listdir(tmpdir)
logger = MLFlowLogger('test', save_dir=tmpdir)
assert logger.save_dir == tmpdir