2020-07-09 11:15:41 +00:00
|
|
|
import os
|
|
|
|
|
2020-09-09 09:38:26 +00:00
|
|
|
from unittest import mock
|
|
|
|
from mlflow.tracking import MlflowClient
|
|
|
|
|
2020-07-09 11:15:41 +00:00
|
|
|
from pytorch_lightning import Trainer
|
2020-03-03 01:49:14 +00:00
|
|
|
from pytorch_lightning.loggers import MLFlowLogger
|
2020-07-09 11:15:41 +00:00
|
|
|
from tests.base import EvalModelTemplate
|
2020-03-03 01:49:14 +00:00
|
|
|
|
|
|
|
|
2020-04-15 00:32:33 +00:00
|
|
|
def test_mlflow_logger_exists(tmpdir):
    """ Test that loggers created with the same experiment name share the experiment id
    but record separate runs, and that a logger with a different name gets a different
    experiment id. """
    logger = MLFlowLogger('test', save_dir=tmpdir)
    # same name leads to same experiment id, but different runs get recorded
    logger2 = MLFlowLogger('test', save_dir=tmpdir)
    assert logger.experiment_id == logger2.experiment_id
    assert logger.run_id != logger2.run_id
    # a different experiment name must map to a different experiment id
    logger3 = MLFlowLogger('new', save_dir=tmpdir)
    assert logger3.experiment_id != logger.experiment_id
|
|
|
|
|
|
|
|
|
|
|
|
def test_mlflow_logger_dirs_creation(tmpdir):
    """ Test that the logger creates the folders and files in the right place.

    Checks the on-disk layout under ``save_dir`` right after construction, after
    repeated ``logger.experiment`` accesses, and after a short training run
    (metrics, params and checkpoints all land inside the same run folder).
    """
    assert not os.listdir(tmpdir)
    logger = MLFlowLogger('test', save_dir=tmpdir)
    assert logger.save_dir == tmpdir
    # directly after construction only the '.trash' folder is expected
    assert set(os.listdir(tmpdir)) == {'.trash'}
    run_id = logger.run_id
    exp_id = logger.experiment_id

    # multiple experiment calls should not lead to new experiment folders
    for _ in range(2):
        _ = logger.experiment
        assert set(os.listdir(tmpdir)) == {'.trash', exp_id}
        assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}

    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
    trainer.fit(model)

    # training must reuse the existing run rather than creating a new one
    assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
    assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
    assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
    # checkpoints are expected inside the MLflow run folder
    assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
2020-09-09 09:38:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_mlflow_experiment_id_retrieved_once(tmpdir):
    """ Test that repeatedly accessing ``logger.experiment`` looks the experiment up by
    name only once. """
    logger = MLFlowLogger('test', save_dir=tmpdir)
    # keep the real bound lookup so the spy can forward calls to it unchanged
    real_lookup = logger._mlflow_client.get_experiment_by_name
    with mock.patch.object(MlflowClient, 'get_experiment_by_name', wraps=real_lookup) as lookup_spy:
        for _ in range(3):
            _ = logger.experiment
        assert lookup_spy.call_count == 1
|