lightning/tests/loggers/test_mlflow.py

55 lines
1.5 KiB
Python

import os
import pickle
from pathlib import Path

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

import tests.base.utils as tutils
from tests.base import LightningTestModel


def test_mlflow_logger(tmpdir):
    """Verify that basic functionality of mlflow logger works.

    Creates an ``MLFlowLogger`` backed by a local file store under ``tmpdir``.
    It then creates a second logger with the same experiment name, logs a
    non-numeric metric value, and runs a short ``Trainer.fit`` with the first
    logger attached. The test passes if nothing raises and ``fit`` returns 1.
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # Build a real file URI. The former f'file:{os.sep * 2}{path}' string is
    # only a valid URI on POSIX; on Windows it became 'file:\\C:\...'.
    # Path.as_uri() yields 'file:///...' on every platform.
    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    tracking_uri = Path(mlflow_dir).as_uri()
    logger = MLFlowLogger('test', tracking_uri=tracking_uri)

    # Test already exists: a second logger shares the 'test' experiment name.
    # Touching its run_id here means the two loggers target the same
    # experiment (presumably exercising the "experiment already exists" path).
    logger2 = MLFlowLogger('test', tracking_uri=tracking_uri)
    _ = logger2.run_id

    # Try logging string: a non-numeric metric value should not raise.
    logger.log_metrics({'acc': 'test'})

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,  # only a small fraction of training data, to keep the test fast
        logger=logger,
    )
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'Training failed'


def test_mlflow_pickle(tmpdir):
    """Verify that pickling trainer with mlflow logger works.

    A ``Trainer`` holding an ``MLFlowLogger`` is round-tripped through
    ``pickle``. The restored trainer's logger must still be an
    ``MLFlowLogger``, and it must be able to log a metric without raising.
    """
    tutils.reset_seed()

    # Path.as_uri() gives a valid 'file:///...' URI on POSIX and Windows
    # alike, unlike the hand-built f'file:{os.sep * 2}{path}' string.
    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    logger = MLFlowLogger('test', tracking_uri=Path(mlflow_dir).as_uri())

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        logger=logger,
    )
    trainer = Trainer(**trainer_options)

    # Round-trip the whole trainer; the restored logger must remain usable.
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    assert isinstance(trainer2.logger, MLFlowLogger)
    trainer2.logger.log_metrics({'acc': 1.0})