Mocking Loggers (part 4c, mlflow) (#3889)
* base * add xfail * new test * import * missing import Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
d71ed277d4
commit
9928125768
|
@ -1,7 +1,10 @@
|
|||
import importlib.util
|
||||
import os
|
||||
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
|
@ -62,8 +65,42 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
|
|||
assert logger3.run_id == "run-id-3"
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
|
||||
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
|
||||
def test_mlflow_log_dir(client, mlflow, tmpdir):
|
||||
""" Test that the trainer saves checkpoints in the logger's save dir. """
|
||||
|
||||
# simulate experiment creation with mlflow client mock
|
||||
run = MagicMock()
|
||||
run.info.run_id = "run-id"
|
||||
client.return_value.get_experiment_by_name = MagicMock(return_value=None)
|
||||
client.return_value.create_experiment = MagicMock(return_value="exp-id")
|
||||
client.return_value.create_run = MagicMock(return_value=run)
|
||||
|
||||
# test construction of default log dir path
|
||||
logger = MLFlowLogger("test", save_dir=tmpdir)
|
||||
assert logger.save_dir == tmpdir
|
||||
assert logger.version == "run-id"
|
||||
assert logger.name == "exp-id"
|
||||
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
logger=logger,
|
||||
max_epochs=1,
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=3,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
|
||||
|
||||
def test_mlflow_logger_dirs_creation(tmpdir):
|
||||
""" Test that the logger creates the folders and files in the right place. """
|
||||
if not importlib.util.find_spec('mlflow'):
|
||||
pytest.xfail(f"test for explicit file creation requires mlflow dependency to be installed.")
|
||||
|
||||
assert not os.listdir(tmpdir)
|
||||
logger = MLFlowLogger('test', save_dir=tmpdir)
|
||||
assert logger.save_dir == tmpdir
|
||||
|
|
Loading…
Reference in New Issue