diff --git a/appveyor.yml b/appveyor.yml
index 98e259cf90..9a9319c7f6 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -59,7 +59,7 @@ before_test:
 
 # to run your custom scripts instead of automatic tests
 test_script:
-  - py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
+  - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
   #- python setup.py sdist
   #- twine check dist/*
diff --git a/pytorch_lightning/logging/base.py b/pytorch_lightning/logging/base.py
index 84d25cbef8..76bd0d2d9a 100644
--- a/pytorch_lightning/logging/base.py
+++ b/pytorch_lightning/logging/base.py
@@ -1,8 +1,9 @@
+from abc import ABC
 from functools import wraps
 
 
 def rank_zero_only(fn):
-    """Decorate a logger method to run it only on the process with rank 0
+    """Decorate a logger method to run it only on the process with rank 0.
 
     :param fn: Function to decorate
     """
@@ -15,12 +16,16 @@ def rank_zero_only(fn):
     return wrapped_fn
 
 
-class LightningLoggerBase(object):
-    """Base class for experiment loggers"""
+class LightningLoggerBase(ABC):
+    """Base class for experiment loggers."""
 
     def __init__(self):
         self._rank = 0
 
+    @property
+    def experiment(self):
+        raise NotImplementedError()
+
     def log_metrics(self, metrics, step):
         """Record metrics.
 
@@ -30,46 +35,43 @@ class LightningLoggerBase(object):
         raise NotImplementedError()
 
     def log_hyperparams(self, params):
-        """Record hyperparameters
+        """Record hyperparameters.
 
         :param params: argparse.Namespace containing the hyperparameters
         """
         raise NotImplementedError()
 
     def save(self):
-        """Save log data"""
+        """Save log data."""
         pass
 
     def finalize(self, status):
-        """Do any processing that is necessary to finalize an experiment
+        """Do any processing that is necessary to finalize an experiment.
 
         :param status: Status that the experiment finished with (e.g. success, failed, aborted)
         """
         pass
 
     def close(self):
-        """Do any cleanup that is necessary to close an experiment"""
+        """Do any cleanup that is necessary to close an experiment."""
         pass
 
     @property
     def rank(self):
-        """
-        Process rank. In general, metrics should only be logged by the process
-        with rank 0
-        """
+        """Process rank. In general, metrics should only be logged by the process with rank 0."""
         return self._rank
 
     @rank.setter
     def rank(self, value):
-        """Set the process rank"""
+        """Set the process rank."""
         self._rank = value
 
     @property
     def name(self):
-        """Return the experiment name"""
+        """Return the experiment name."""
         raise NotImplementedError("Sub-classes must provide a name property")
 
     @property
     def version(self):
-        """Return the experiment version"""
+        """Return the experiment version."""
         raise NotImplementedError("Sub-classes must provide a version property")
diff --git a/pytorch_lightning/logging/comet.py b/pytorch_lightning/logging/comet.py
index cbf998388a..b74760665b 100644
--- a/pytorch_lightning/logging/comet.py
+++ b/pytorch_lightning/logging/comet.py
@@ -67,8 +67,8 @@ logger = getLogger(__name__)
 class CometLogger(LightningLoggerBase):
     def __init__(self, api_key=None, save_dir=None, workspace=None,
                  rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
-        """
-        Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)
+        """Initialize a Comet.ml logger.
+        Requires either an API Key (online mode) or a local directory path (offline mode)
 
         :param str api_key: Required in online mode. API key, found on Comet.ml
         :param str save_dir: Required in offline mode. The path for the directory to save local comet logs
diff --git a/pytorch_lightning/logging/mlflow.py b/pytorch_lightning/logging/mlflow.py
index 6ed41ee1d6..fde2945b27 100644
--- a/pytorch_lightning/logging/mlflow.py
+++ b/pytorch_lightning/logging/mlflow.py
@@ -40,25 +40,29 @@ logger = getLogger(__name__)
 class MLFlowLogger(LightningLoggerBase):
     def __init__(self, experiment_name, tracking_uri=None, tags=None):
         super().__init__()
-        self.experiment = mlflow.tracking.MlflowClient(tracking_uri)
+        self._mlflow_client = mlflow.tracking.MlflowClient(tracking_uri)
         self.experiment_name = experiment_name
         self._run_id = None
         self.tags = tags
 
+    @property
+    def experiment(self):
+        return self._mlflow_client
+
     @property
     def run_id(self):
         if self._run_id is not None:
             return self._run_id
 
-        experiment = self.experiment.get_experiment_by_name(self.experiment_name)
-        if experiment is None:
-            logger.warning(
-                f"Experiment with name f{self.experiment_name} not found. Creating it."
-            )
-            self.experiment.create_experiment(self.experiment_name)
-            experiment = self.experiment.get_experiment_by_name(self.experiment_name)
+        expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
 
-        run = self.experiment.create_run(experiment.experiment_id, tags=self.tags)
+        if expt:
+            self._expt_id = expt.experiment_id
+        else:
+            logger.warning(f"Experiment with name {self.experiment_name} not found. Creating it.")
+            self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
+
+        run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
         self._run_id = run.info.run_id
         return self._run_id
 
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 3d36359761..bf35c2cc46 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -63,15 +63,13 @@ def test_mlflow_logger(tmpdir):
     model = LightningTestModel(hparams)
 
     mlflow_dir = os.path.join(tmpdir, "mlruns")
-
-    logger = MLFlowLogger("test", f"file://{mlflow_dir}")
+    logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
 
     trainer_options = dict(
         max_epochs=1,
         train_percent_check=0.01,
         logger=logger
     )
-
     trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
 
@@ -88,13 +86,11 @@ def test_mlflow_pickle(tmpdir):
     except ModuleNotFoundError:
         return
 
-    hparams = tutils.get_hparams()
-    model = LightningTestModel(hparams)
+    # hparams = tutils.get_hparams()
+    # model = LightningTestModel(hparams)
 
     mlflow_dir = os.path.join(tmpdir, "mlruns")
-
-    logger = MLFlowLogger("test", f"file://{mlflow_dir}")
-
+    logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
    trainer_options = dict(
         max_epochs=1,
         logger=logger
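
Usage note (a minimal sketch, not part of the patch): how MLFlowLogger behaves after this change. The experiment name, directory path, and Trainer options below are illustrative assumptions rather than values taken from the diff; it presumes mlflow and this branch of pytorch-lightning are installed.

    import os

    from pytorch_lightning import Trainer
    from pytorch_lightning.logging import MLFlowLogger

    # Illustrative local tracking location (hypothetical path).
    mlflow_dir = "/tmp/mlruns"
    logger = MLFlowLogger("demo_experiment", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")

    # `experiment` is now a read-only property returning the wrapped MlflowClient,
    # satisfying the abstract property added to LightningLoggerBase.
    client = logger.experiment

    # Accessing `run_id` lazily creates the MLflow experiment (if absent) and a run,
    # reusing the id returned by create_experiment instead of a second name lookup.
    print(logger.run_id)

    trainer = Trainer(max_epochs=1, logger=logger)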