diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 88bed79904..eff5e3305f 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -30,10 +30,23 @@ _MLFLOW_AVAILABLE = _module_available("mlflow") try: import mlflow from mlflow.tracking import MlflowClient + from mlflow.tracking import context # todo: there seems to be still some remaining import error with Conda env except ImportError: _MLFLOW_AVAILABLE = False - mlflow, MlflowClient = None, None + mlflow, MlflowClient, context = None, None, None + + +# before v1.1.0 +if hasattr(context, 'resolve_tags'): + from mlflow.tracking.context import resolve_tags +# since v1.1.0 +elif hasattr(context, 'registry'): + from mlflow.tracking.context.registry import resolve_tags +else: + + def resolve_tags(tags=None): + return tags class MLFlowLogger(LightningLoggerBase): @@ -140,7 +153,7 @@ class MLFlowLogger(LightningLoggerBase): ) if self._run_id is None: - run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) + run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id return self._mlflow_client