add tags argument to MLFlowLogger (#349)
This commit is contained in:
parent
453568179b
commit
0eab1e42b2
|
@ -9,11 +9,12 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class MLFlowLogger(LightningLoggerBase):
|
||||
def __init__(self, experiment_name, tracking_uri=None):
|
||||
def __init__(self, experiment_name, tracking_uri=None, tags=None):
|
||||
super().__init__()
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri)
|
||||
self.experiment_name = experiment_name
|
||||
self._run_id = None
|
||||
self.tags = tags
|
||||
|
||||
@property
|
||||
def run_id(self):
|
||||
|
@ -28,7 +29,7 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
self.client.create_experiment(self.experiment_name)
|
||||
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
|
||||
run = self.client.create_run(experiment.experiment_id)
|
||||
run = self.client.create_run(experiment.experiment_id, tags=self.tags)
|
||||
self._run_id = run.info.run_id
|
||||
return self._run_id
|
||||
|
||||
|
|
Loading…
Reference in New Issue