add tags argument to MLFlowLogger (#349)

This commit is contained in:
festeh 2019-10-10 00:47:17 +03:00 committed by William Falcon
parent 453568179b
commit 0eab1e42b2
1 changed files with 3 additions and 2 deletions

View File

@ -9,11 +9,12 @@ logger = getLogger(__name__)
class MLFlowLogger(LightningLoggerBase):
def __init__(self, experiment_name, tracking_uri=None):
def __init__(self, experiment_name, tracking_uri=None, tags=None):
super().__init__()
self.client = mlflow.tracking.MlflowClient(tracking_uri)
self.experiment_name = experiment_name
self._run_id = None
self.tags = tags
@property
def run_id(self):
@ -28,7 +29,7 @@ class MLFlowLogger(LightningLoggerBase):
self.client.create_experiment(self.experiment_name)
experiment = self.client.get_experiment_by_name(self.experiment_name)
run = self.client.create_run(experiment.experiment_id)
run = self.client.create_run(experiment.experiment_id, tags=self.tags)
self._run_id = run.info.run_id
return self._run_id