Update mlflow with using resolve_tags (#6746)

* Update mlflow.py

#6745 adds additional info about the run, as in the native API

* Update mlflow.py

trying to fix some backward compatibility issues with `resolve_tags`

* wip on backward compatibility

added a default for `getattr` in case the `registry` object exists, but has no proper attribute (weird case but who knows...)

* fix pep

* impoert

* fix registry import

* try fix failing tests

removed the first if statement, so that `resolve_tags` would be defined either case

* fix formatting

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
Oleg 2021-04-08 16:45:23 +07:00 committed by GitHub
parent eb15abcd82
commit 3007872d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 2 deletions

View File

@ -30,10 +30,23 @@ _MLFLOW_AVAILABLE = _module_available("mlflow")
try:
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.tracking import context
# todo: there seems to be still some remaining import error with Conda env
except ImportError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient = None, None
mlflow, MlflowClient, context = None, None, None
# before v1.1.0
if hasattr(context, 'resolve_tags'):
from mlflow.tracking.context import resolve_tags
# since v1.1.0
elif hasattr(context, 'registry'):
from mlflow.tracking.context.registry import resolve_tags
else:
def resolve_tags(tags=None):
return tags
class MLFlowLogger(LightningLoggerBase):
@ -140,7 +153,7 @@ class MLFlowLogger(LightningLoggerBase):
)
if self._run_id is None:
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
return self._mlflow_client