fix Logger tests for Win (#605)

* fix mlflow test

* fix mlflow test

* update logger / mlflow

* flake8

* fix appveyor
This commit is contained in:
Jirka Borovec 2019-12-08 01:25:12 +01:00 committed by William Falcon
parent 58cc6e13b9
commit 4970624f8b
5 changed files with 36 additions and 34 deletions

View File

@ -59,7 +59,7 @@ before_test:
# to run your custom scripts instead of automatic tests
test_script:
- py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
#- python setup.py sdist
#- twine check dist/*

View File

@ -1,8 +1,9 @@
from abc import ABC
from functools import wraps
def rank_zero_only(fn):
"""Decorate a logger method to run it only on the process with rank 0
"""Decorate a logger method to run it only on the process with rank 0.
:param fn: Function to decorate
"""
@ -15,12 +16,16 @@ def rank_zero_only(fn):
return wrapped_fn
class LightningLoggerBase(object):
"""Base class for experiment loggers"""
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""
def __init__(self):
self._rank = 0
@property
def experiment(self):
raise NotImplementedError()
def log_metrics(self, metrics, step):
"""Record metrics.
@ -30,46 +35,43 @@ class LightningLoggerBase(object):
raise NotImplementedError()
def log_hyperparams(self, params):
"""Record hyperparameters
"""Record hyperparameters.
:param params: argparse.Namespace containing the hyperparameters
"""
raise NotImplementedError()
def save(self):
"""Save log data"""
"""Save log data."""
pass
def finalize(self, status):
"""Do any processing that is necessary to finalize an experiment
"""Do any processing that is necessary to finalize an experiment.
:param status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass
def close(self):
"""Do any cleanup that is necessary to close an experiment"""
"""Do any cleanup that is necessary to close an experiment."""
pass
@property
def rank(self):
"""
Process rank. In general, metrics should only be logged by the process
with rank 0
"""
"""Process rank. In general, metrics should only be logged by the process with rank 0."""
return self._rank
@rank.setter
def rank(self, value):
"""Set the process rank"""
"""Set the process rank."""
self._rank = value
@property
def name(self):
"""Return the experiment name"""
"""Return the experiment name."""
raise NotImplementedError("Sub-classes must provide a name property")
@property
def version(self):
"""Return the experiment version"""
"""Return the experiment version."""
raise NotImplementedError("Sub-classes must provide a version property")

View File

@ -67,8 +67,8 @@ logger = getLogger(__name__)
class CometLogger(LightningLoggerBase):
def __init__(self, api_key=None, save_dir=None, workspace=None,
rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
"""
Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)
"""Initialize a Comet.ml logger.
Requires either an API Key (online mode) or a local directory path (offline mode)
:param str api_key: Required in online mode. API key, found on Comet.ml
:param str save_dir: Required in offline mode. The path for the directory to save local comet logs

View File

@ -40,25 +40,29 @@ logger = getLogger(__name__)
class MLFlowLogger(LightningLoggerBase):
def __init__(self, experiment_name, tracking_uri=None, tags=None):
super().__init__()
self.experiment = mlflow.tracking.MlflowClient(tracking_uri)
self._mlflow_client = mlflow.tracking.MlflowClient(tracking_uri)
self.experiment_name = experiment_name
self._run_id = None
self.tags = tags
@property
def experiment(self):
return self._mlflow_client
@property
def run_id(self):
if self._run_id is not None:
return self._run_id
experiment = self.experiment.get_experiment_by_name(self.experiment_name)
if experiment is None:
logger.warning(
f"Experiment with name f{self.experiment_name} not found. Creating it."
)
self.experiment.create_experiment(self.experiment_name)
experiment = self.experiment.get_experiment_by_name(self.experiment_name)
expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
run = self.experiment.create_run(experiment.experiment_id, tags=self.tags)
if expt:
self._expt_id = expt.experiment_id
else:
logger.warning(f"Experiment with name f{self.experiment_name} not found. Creating it.")
self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
self._run_id = run.info.run_id
return self._run_id

View File

@ -63,15 +63,13 @@ def test_mlflow_logger(tmpdir):
model = LightningTestModel(hparams)
mlflow_dir = os.path.join(tmpdir, "mlruns")
logger = MLFlowLogger("test", f"file://{mlflow_dir}")
logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
trainer_options = dict(
max_epochs=1,
train_percent_check=0.01,
logger=logger
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
@ -88,13 +86,11 @@ def test_mlflow_pickle(tmpdir):
except ModuleNotFoundError:
return
hparams = tutils.get_hparams()
model = LightningTestModel(hparams)
# hparams = tutils.get_hparams()
# model = LightningTestModel(hparams)
mlflow_dir = os.path.join(tmpdir, "mlruns")
logger = MLFlowLogger("test", f"file://{mlflow_dir}")
logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
trainer_options = dict(
max_epochs=1,
logger=logger