From 3007872d018c30ec4f81b6604e2295f62554cb46 Mon Sep 17 00:00:00 2001 From: Oleg Date: Thu, 8 Apr 2021 16:45:23 +0700 Subject: [PATCH] Update mlflow with using resolve_tags (#6746) * Update mlflow.py #6745 adds additional info about the run, as in the native API * Update mlflow.py trying to fix some backward compatibility issues with `resolve_tags` * wip on backward compatibility added a default for `getattr` in case the `registry` object exists, but has no proper attribute (weird case but who knows...) * fix pep * impoert * fix registry import * try fix failing tests removed the first if statement, so that `resolve_tags` would be defined either case * fix formatting Co-authored-by: Jirka Borovec --- pytorch_lightning/loggers/mlflow.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 88bed79904..eff5e3305f 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -30,10 +30,23 @@ _MLFLOW_AVAILABLE = _module_available("mlflow") try: import mlflow from mlflow.tracking import MlflowClient + from mlflow.tracking import context # todo: there seems to be still some remaining import error with Conda env except ImportError: _MLFLOW_AVAILABLE = False - mlflow, MlflowClient = None, None + mlflow, MlflowClient, context = None, None, None + + +# before v1.1.0 +if hasattr(context, 'resolve_tags'): + from mlflow.tracking.context import resolve_tags +# since v1.1.0 +elif hasattr(context, 'registry'): + from mlflow.tracking.context.registry import resolve_tags +else: + + def resolve_tags(tags=None): + return tags class MLFlowLogger(LightningLoggerBase): @@ -140,7 +153,7 @@ class MLFlowLogger(LightningLoggerBase): ) if self._run_id is None: - run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) + run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id return self._mlflow_client