diff --git a/pytorch_lightning/logging/mlflow_logger.py b/pytorch_lightning/logging/mlflow_logger.py index 970dfe0309..8fb79f5366 100644 --- a/pytorch_lightning/logging/mlflow_logger.py +++ b/pytorch_lightning/logging/mlflow_logger.py @@ -9,11 +9,12 @@ logger = getLogger(__name__) class MLFlowLogger(LightningLoggerBase): - def __init__(self, experiment_name, tracking_uri=None): + def __init__(self, experiment_name, tracking_uri=None, tags=None): super().__init__() self.client = mlflow.tracking.MlflowClient(tracking_uri) self.experiment_name = experiment_name self._run_id = None + self.tags = tags @property def run_id(self): @@ -28,7 +29,7 @@ class MLFlowLogger(LightningLoggerBase): self.client.create_experiment(self.experiment_name) experiment = self.client.get_experiment_by_name(self.experiment_name) - run = self.client.create_run(experiment.experiment_id) + run = self.client.create_run(experiment.experiment_id, tags=self.tags) self._run_id = run.info.run_id return self._run_id