Initialize loggers only once (#270)
* Create underlying loggers lazily This avoids creating duplicate experiments or run in multi-node DDP. * Save hyperparameters automatically * Update docs for snapshotting hyperparams * Fix test tube * Fix test tube pickling
This commit is contained in:
parent
222d7d2d5d
commit
614cb3c03b
|
@ -119,12 +119,21 @@ trainer = Trainer(process_position=1)
|
|||
|
||||
---
|
||||
#### Save a snapshot of all hyperparameters
|
||||
Log hyperparameters using the logger
|
||||
Automatically log hyperparameters stored in the `hparams` attribute as an `argparse.Namespace`
|
||||
``` {.python}
|
||||
logger = TestTubeLogger(...)
|
||||
logger.log_hyperparams(args)
|
||||
|
||||
Trainer(logger=logger)
|
||||
class MyModel(pl.Lightning):
|
||||
def __init__(self, hparams):
|
||||
self.hparams = hparams
|
||||
|
||||
...
|
||||
|
||||
args = parser.parse_args()
|
||||
model = MyModel(args)
|
||||
|
||||
logger = TestTubeLogger(...)
|
||||
t = Trainer(logger=logger)
|
||||
trainer.fit(model)
|
||||
```
|
||||
|
||||
---
|
||||
|
|
|
@ -9,21 +9,28 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class MLFlowLogger(LightningLoggerBase):
|
||||
|
||||
def __init__(self, experiment_name, tracking_uri=None):
|
||||
super().__init__()
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri)
|
||||
self.experiment_name = experiment_name
|
||||
self._run_id = None
|
||||
|
||||
experiment = self.client.get_experiment_by_name(experiment_name)
|
||||
@property
|
||||
def run_id(self):
|
||||
if self._run_id is not None:
|
||||
return self._run_id
|
||||
|
||||
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
if experiment is None:
|
||||
logger.warning(
|
||||
f"Experiment with name f{experiment_name} not found. Creating it."
|
||||
f"Experiment with name f{self.experiment_name} not found. Creating it."
|
||||
)
|
||||
self.client.create_experiment(experiment_name)
|
||||
experiment = self.client.get_experiment_by_name(experiment_name)
|
||||
self.client.create_experiment(self.experiment_name)
|
||||
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
|
||||
run = self.client.create_run(experiment.experiment_id)
|
||||
self.run_id = run.info.run_id
|
||||
self._run_id = run.info.run_id
|
||||
return self._run_id
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
|
|
|
@ -11,13 +11,26 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
self, save_dir, name="default", debug=False, version=None, create_git_tag=False
|
||||
):
|
||||
super().__init__()
|
||||
self.experiment = Experiment(
|
||||
save_dir=save_dir,
|
||||
name=name,
|
||||
debug=debug,
|
||||
version=version,
|
||||
create_git_tag=create_git_tag,
|
||||
self.save_dir = save_dir
|
||||
self.name = name
|
||||
self.debug = debug
|
||||
self._version = version
|
||||
self.create_git_tag = create_git_tag
|
||||
self._experiment = None
|
||||
|
||||
@property
|
||||
def experiment(self):
|
||||
if self._experiment is not None:
|
||||
return self._experiment
|
||||
self._experiment = Experiment(
|
||||
save_dir=self.save_dir,
|
||||
name=self.name,
|
||||
debug=self.debug,
|
||||
version=self.version,
|
||||
create_git_tag=self.create_git_tag,
|
||||
rank=self.rank,
|
||||
)
|
||||
return self._experiment
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
|
@ -41,14 +54,23 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
|
||||
@property
|
||||
def rank(self):
|
||||
if self._experiment is None:
|
||||
return self._rank
|
||||
else:
|
||||
return self.experiment.rank
|
||||
|
||||
@rank.setter
|
||||
def rank(self, value):
|
||||
self.experiment.rank = value
|
||||
if self._experiment is None:
|
||||
self._rank = value
|
||||
else:
|
||||
return self.experiment.rank
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
if self._experiment is None:
|
||||
return self._version
|
||||
else:
|
||||
return self.experiment.version
|
||||
|
||||
# Test tube experiments are not pickleable, so we need to override a few
|
||||
|
@ -57,10 +79,10 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
# for more info.
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["experiment"] = self.experiment.get_meta_copy()
|
||||
state["_experiment"] = self.experiment.get_meta_copy()
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.experiment = state["experiment"].get_non_ddp_exp()
|
||||
del state['experiment']
|
||||
self._experiment = state["_experiment"].get_non_ddp_exp()
|
||||
del state["_experiment"]
|
||||
self.__dict__.update(state)
|
||||
|
|
|
@ -890,7 +890,8 @@ class Trainer(TrainerIO):
|
|||
ref_model.logger = self.logger
|
||||
|
||||
# save exp to get started
|
||||
if self.proc_rank == 0:
|
||||
if hasattr(ref_model, "hparams"):
|
||||
self.logger.log_hyperparams(ref_model.hparams)
|
||||
self.logger.save()
|
||||
|
||||
# track model now.
|
||||
|
|
|
@ -19,8 +19,6 @@ def test_testtube_logger():
|
|||
save_dir = init_save_dir()
|
||||
|
||||
logger = get_test_tube_logger(False)
|
||||
logger.log_hyperparams(hparams)
|
||||
logger.save()
|
||||
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
|
|
Loading…
Reference in New Issue