diff --git a/docs/Trainer/Logging.md b/docs/Trainer/Logging.md
index 765e1afd81..b6d5e47cce 100644
--- a/docs/Trainer/Logging.md
+++ b/docs/Trainer/Logging.md
@@ -119,12 +119,22 @@ trainer = Trainer(process_position=1)
 
 ---
 #### Save a snapshot of all hyperparameters
-Log hyperparameters using the logger
+Automatically log the hyperparameters stored in the model's `hparams` attribute (an `argparse.Namespace`)
 
 ``` {.python}
-logger = TestTubeLogger(...)
-logger.log_hyperparams(args)
-Trainer(logger=logger)
+class MyModel(pl.LightningModule):
+    def __init__(self, hparams):
+        super().__init__()
+        self.hparams = hparams
+
+    ...
+
+args = parser.parse_args()
+model = MyModel(args)
+
+logger = TestTubeLogger(...)
+trainer = Trainer(logger=logger)
+trainer.fit(model)
 ```
 
 ---
diff --git a/pytorch_lightning/logging/mlflow_logger.py b/pytorch_lightning/logging/mlflow_logger.py
index cc76ec4d83..70609ea334 100644
--- a/pytorch_lightning/logging/mlflow_logger.py
+++ b/pytorch_lightning/logging/mlflow_logger.py
@@ -9,21 +9,28 @@ logger = getLogger(__name__)
 
 
 class MLFlowLogger(LightningLoggerBase):
-
     def __init__(self, experiment_name, tracking_uri=None):
         super().__init__()
         self.client = mlflow.tracking.MlflowClient(tracking_uri)
+        self.experiment_name = experiment_name
+        self._run_id = None
 
-        experiment = self.client.get_experiment_by_name(experiment_name)
+    @property
+    def run_id(self):
+        if self._run_id is not None:
+            return self._run_id
+
+        experiment = self.client.get_experiment_by_name(self.experiment_name)
         if experiment is None:
             logger.warning(
-                f"Experiment with name f{experiment_name} not found. Creating it."
+                f"Experiment with name {self.experiment_name} not found. Creating it."
             )
-            self.client.create_experiment(experiment_name)
-            experiment = self.client.get_experiment_by_name(experiment_name)
+            self.client.create_experiment(self.experiment_name)
+            experiment = self.client.get_experiment_by_name(self.experiment_name)
 
         run = self.client.create_run(experiment.experiment_id)
-        self.run_id = run.info.run_id
+        self._run_id = run.info.run_id
+        return self._run_id
 
     @rank_zero_only
     def log_hyperparams(self, params):
diff --git a/pytorch_lightning/logging/test_tube_logger.py b/pytorch_lightning/logging/test_tube_logger.py
index 51222efe0e..c20cc1d3bf 100644
--- a/pytorch_lightning/logging/test_tube_logger.py
+++ b/pytorch_lightning/logging/test_tube_logger.py
@@ -11,13 +11,26 @@ class TestTubeLogger(LightningLoggerBase):
         self, save_dir, name="default", debug=False, version=None, create_git_tag=False
     ):
         super().__init__()
-        self.experiment = Experiment(
-            save_dir=save_dir,
-            name=name,
-            debug=debug,
-            version=version,
-            create_git_tag=create_git_tag,
+        self.save_dir = save_dir
+        self.name = name
+        self.debug = debug
+        self._version = version
+        self.create_git_tag = create_git_tag
+        self._experiment = None
+
+    @property
+    def experiment(self):
+        if self._experiment is not None:
+            return self._experiment
+        self._experiment = Experiment(
+            save_dir=self.save_dir,
+            name=self.name,
+            debug=self.debug,
+            version=self.version,
+            create_git_tag=self.create_git_tag,
+            rank=self.rank,
         )
+        return self._experiment
 
     @rank_zero_only
     def log_hyperparams(self, params):
@@ -41,15 +54,24 @@ class TestTubeLogger(LightningLoggerBase):
 
     @property
     def rank(self):
-        return self.experiment.rank
+        if self._experiment is None:
+            return self._rank
+        else:
+            return self.experiment.rank
 
     @rank.setter
    def rank(self, value):
-        self.experiment.rank = value
+        if self._experiment is None:
+            self._rank = value
+        else:
+            self.experiment.rank = value
 
     @property
     def version(self):
-        return self.experiment.version
+        if self._experiment is None:
+            return self._version
+        else:
+            return self.experiment.version
 
     # Test tube experiments are not pickleable, so we need to override a few
     # methods to get DDP working. See
@@ -57,10 +79,10 @@ class TestTubeLogger(LightningLoggerBase):
     # for more info.
     def __getstate__(self):
         state = self.__dict__.copy()
-        state["experiment"] = self.experiment.get_meta_copy()
+        state["_experiment"] = self.experiment.get_meta_copy()
         return state
 
     def __setstate__(self, state):
-        self.experiment = state["experiment"].get_non_ddp_exp()
-        del state['experiment']
+        self._experiment = state["_experiment"].get_non_ddp_exp()
+        del state["_experiment"]
         self.__dict__.update(state)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6c02d9b3fd..ececc7b317 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -890,8 +890,9 @@ class Trainer(TrainerIO):
         ref_model.logger = self.logger
 
         # save exp to get started
-        if self.proc_rank == 0:
-            self.logger.save()
+        if hasattr(ref_model, "hparams"):
+            self.logger.log_hyperparams(ref_model.hparams)
+        self.logger.save()
 
         # track model now.
         # if cluster resets state, the model will update with the saved weights
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 1a78720795..742ef7a76b 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -19,8 +19,6 @@ def test_testtube_logger():
     save_dir = init_save_dir()
 
     logger = get_test_tube_logger(False)
-    logger.log_hyperparams(hparams)
-    logger.save()
 
     trainer_options = dict(
         max_nb_epochs=1,
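
The core trick in the patch is deferring construction of the underlying experiment until first access, so a freshly built logger holds only plain attributes and pickles cleanly for DDP. Below is a minimal, self-contained sketch of that lazy-property pattern; `FakeExperiment` and `LazyLogger` are hypothetical stand-ins for illustration, not test_tube's `Experiment` or the real `TestTubeLogger`.

```python
import pickle


class FakeExperiment:
    """Hypothetical stand-in for test_tube's Experiment, which (per the
    comment in the patch) is not pickleable once constructed."""

    def __init__(self, save_dir, name):
        self.save_dir = save_dir
        self.name = name


class LazyLogger:
    def __init__(self, save_dir, name="default"):
        # Store only constructor arguments; create nothing yet.
        self.save_dir = save_dir
        self.name = name
        self._experiment = None

    @property
    def experiment(self):
        # Build the experiment on first access. A logger pickled before
        # this point carries no unpicklable state, so each DDP worker can
        # unpickle it and construct its own experiment locally.
        if self._experiment is None:
            self._experiment = FakeExperiment(self.save_dir, self.name)
        return self._experiment


logger = LazyLogger("/tmp/logs")
clone = pickle.loads(pickle.dumps(logger))  # fine: no experiment exists yet
print(clone.experiment.name)  # "default" -- created lazily on first access
```

When the experiment has already been materialized before pickling, lazy creation alone is not enough; that is why the patch also overrides `__getstate__`/`__setstate__` to swap the live experiment for a picklable meta copy under the `_experiment` key.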