2020-03-03 01:49:14 +00:00
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
|
2020-03-25 11:46:27 +00:00
|
|
|
import tests.base.utils as tutils
|
2020-03-03 01:49:14 +00:00
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from pytorch_lightning.loggers import MLFlowLogger
|
2020-03-25 11:46:27 +00:00
|
|
|
from tests.base import LightningTestModel
|
2020-03-03 01:49:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_mlflow_logger(tmpdir):
    """Verify that basic functionality of mlflow logger works."""
    from pathlib import Path

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    # ``Path.as_uri`` yields a well-formed ``file:///...`` URI on every
    # platform. Hand-building it with ``os.sep * 2`` only works on POSIX; on
    # Windows it produces backslashes instead of a valid file URI.
    tracking_uri = Path(mlflow_dir).as_uri()
    logger = MLFlowLogger('test', tracking_uri=tracking_uri)

    # Test already exists: build a second logger for the same experiment name
    # and touch ``run_id`` so it has to resolve its run against that experiment.
    logger2 = MLFlowLogger('test', tracking_uri=tracking_uri)
    _ = logger2.run_id

    # Try logging string: a non-numeric metric value must not raise.
    logger.log_metrics({'acc': 'test'})

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'Training failed'
|
|
|
|
|
|
|
|
|
|
|
|
def test_mlflow_pickle(tmpdir):
    """Verify that pickling trainer with mlflow logger works."""
    from pathlib import Path

    tutils.reset_seed()

    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    # ``Path.as_uri`` yields a well-formed ``file:///...`` URI on every
    # platform. Hand-building it with ``os.sep * 2`` only works on POSIX.
    tracking_uri = Path(mlflow_dir).as_uri()
    logger = MLFlowLogger('test', tracking_uri=tracking_uri)

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        logger=logger,
    )
    trainer = Trainer(**trainer_options)

    # Round-trip the whole trainer (including its logger) through pickle.
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)

    # The restored logger must still be usable after unpickling.
    trainer2.logger.log_metrics({'acc': 1.0})
|