Initialize loggers only once (#270)

* Create underlying loggers lazily

This avoids creating duplicate experiments or runs in multi-node DDP.

* Save hyperparameters automatically

* Update docs for snapshotting hyperparams

* Fix test tube

* Fix test tube pickling
Nic Eggert 2019-10-02 10:10:40 -05:00 committed by William Falcon
parent 222d7d2d5d
commit 614cb3c03b
5 changed files with 63 additions and 26 deletions
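
The core of the change is a single pattern, applied to both loggers: store the constructor arguments, and build the underlying experiment object only on first access. A minimal sketch of that pattern (illustrative only; `LazyLogger` and its dict backend are stand-ins for the real `Experiment`/`MlflowClient` setup):

``` {.python}
class LazyLogger:
    """Illustrative sketch: defer creating the logging backend until first use."""

    def __init__(self, save_dir):
        # Record constructor arguments only. Building the backend here would
        # run once per DDP process and create duplicate experiments/runs.
        self.save_dir = save_dir
        self._backend = None

    @property
    def backend(self):
        # First access builds the backend; later accesses reuse the cached one.
        if self._backend is None:
            self._backend = {"save_dir": self.save_dir}  # stand-in for Experiment(...)
        return self._backend


logger = LazyLogger("/tmp/logs")
assert logger._backend is None   # nothing created at construction time
_ = logger.backend               # backend materializes on first access
assert logger._backend is not None
```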

View File

@@ -119,12 +119,21 @@ trainer = Trainer(process_position=1)
 ---
 #### Save a snapshot of all hyperparameters
-Log hyperparameters using the logger
+Automatically log hyperparameters stored in the `hparams` attribute as an `argparse.Namespace`
 ``` {.python}
-logger = TestTubeLogger(...)
-logger.log_hyperparams(args)
-Trainer(logger=logger)
+class MyModel(pl.LightningModule):
+    def __init__(self, hparams):
+        super().__init__()
+        self.hparams = hparams
+        ...
+
+args = parser.parse_args()
+model = MyModel(args)
+
+logger = TestTubeLogger(...)
+trainer = Trainer(logger=logger)
+trainer.fit(model)
 ```
 ---

View File

@@ -9,21 +9,28 @@ logger = getLogger(__name__)
 class MLFlowLogger(LightningLoggerBase):
     def __init__(self, experiment_name, tracking_uri=None):
         super().__init__()
         self.client = mlflow.tracking.MlflowClient(tracking_uri)
-        experiment = self.client.get_experiment_by_name(experiment_name)
+        self.experiment_name = experiment_name
+        self._run_id = None
+
+    @property
+    def run_id(self):
+        if self._run_id is not None:
+            return self._run_id
+        experiment = self.client.get_experiment_by_name(self.experiment_name)
         if experiment is None:
             logger.warning(
-                f"Experiment with name f{experiment_name} not found. Creating it."
+                f"Experiment with name {self.experiment_name} not found. Creating it."
             )
-            self.client.create_experiment(experiment_name)
-            experiment = self.client.get_experiment_by_name(experiment_name)
+            self.client.create_experiment(self.experiment_name)
+            experiment = self.client.get_experiment_by_name(self.experiment_name)
         run = self.client.create_run(experiment.experiment_id)
-        self.run_id = run.info.run_id
+        self._run_id = run.info.run_id
+        return self._run_id

     @rank_zero_only
     def log_hyperparams(self, params):
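
With this change, constructing an `MLFlowLogger` no longer touches the tracking server; the experiment and run are created on the first read of `run_id` and cached afterwards. A rough usage sketch (assuming `mlflow` is installed; with no `tracking_uri`, the client writes to a local `./mlruns` directory):

``` {.python}
logger = MLFlowLogger(experiment_name="demo")
# Construction creates no experiment and no run.
rid = logger.run_id           # first access creates both on demand
assert logger.run_id == rid   # cached: every later access reuses the same run
```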

View File

@@ -11,13 +11,26 @@ class TestTubeLogger(LightningLoggerBase):
         self, save_dir, name="default", debug=False, version=None, create_git_tag=False
     ):
         super().__init__()
-        self.experiment = Experiment(
-            save_dir=save_dir,
-            name=name,
-            debug=debug,
-            version=version,
-            create_git_tag=create_git_tag,
+        self.save_dir = save_dir
+        self.name = name
+        self.debug = debug
+        self._version = version
+        self.create_git_tag = create_git_tag
+        self._experiment = None
+
+    @property
+    def experiment(self):
+        if self._experiment is not None:
+            return self._experiment
+        self._experiment = Experiment(
+            save_dir=self.save_dir,
+            name=self.name,
+            debug=self.debug,
+            version=self.version,
+            create_git_tag=self.create_git_tag,
+            rank=self.rank,
         )
+        return self._experiment

     @rank_zero_only
     def log_hyperparams(self, params):
@@ -41,14 +54,23 @@ class TestTubeLogger(LightningLoggerBase):
     @property
     def rank(self):
-        return self._rank
+        if self._experiment is None:
+            return self._rank
+        else:
+            return self.experiment.rank

     @rank.setter
     def rank(self, value):
-        self.experiment.rank = value
+        if self._experiment is None:
+            self._rank = value
+        else:
+            self.experiment.rank = value
+
+    @property
+    def version(self):
+        if self._experiment is None:
+            return self._version
+        else:
+            return self.experiment.version

     # Test tube experiments are not pickleable, so we need to override a few
@@ -57,10 +79,10 @@ class TestTubeLogger(LightningLoggerBase):
     # for more info.
     def __getstate__(self):
         state = self.__dict__.copy()
-        state["experiment"] = self.experiment.get_meta_copy()
+        state["_experiment"] = self.experiment.get_meta_copy()
         return state

     def __setstate__(self, state):
-        self.experiment = state["experiment"].get_non_ddp_exp()
-        del state['experiment']
+        self._experiment = state["_experiment"].get_non_ddp_exp()
+        del state["_experiment"]
         self.__dict__.update(state)
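
A side effect of the lazy `experiment` property is that attributes assigned before first access, such as the `rank` the trainer sets on each DDP process, are buffered on the logger and only handed to `Experiment` when it is finally built. A small sketch of the intended behaviour (assuming `test-tube` is installed and that the base class initializes `_rank`, as the `rank` getter above implies; the path is arbitrary):

``` {.python}
logger = TestTubeLogger(save_dir="/tmp/logs")
logger.rank = 1                    # buffered in self._rank; no Experiment yet
assert logger._experiment is None
exp = logger.experiment            # built now, with rank=1 passed through
assert logger._experiment is exp   # later accesses reuse the same object
```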

View File

@@ -890,7 +890,8 @@ class Trainer(TrainerIO):
         ref_model.logger = self.logger

         # save exp to get started
         if self.proc_rank == 0:
+            if hasattr(ref_model, "hparams"):
+                self.logger.log_hyperparams(ref_model.hparams)
             self.logger.save()

         # track model now.

View File

@@ -19,8 +19,6 @@ def test_testtube_logger():
     save_dir = init_save_dir()

     logger = get_test_tube_logger(False)
-    logger.log_hyperparams(hparams)
-    logger.save()

     trainer_options = dict(
         max_nb_epochs=1,