Update mlflow with using resolve_tags (#6746)
* Update mlflow.py #6745 adds additional info about the run, as in the native API * Update mlflow.py trying to fix some backward compatibility issues with `resolve_tags` * wip on backward compatibility added a default for `getattr` in case the `registry` object exists, but has no proper attribute (weird case but who knows...) * fix pep * impoert * fix registry import * try fix failing tests removed the first if statement, so that `resolve_tags` would be defined either case * fix formatting Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
parent
eb15abcd82
commit
3007872d01
|
@ -30,10 +30,23 @@ _MLFLOW_AVAILABLE = _module_available("mlflow")
|
|||
try:
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.tracking import context
|
||||
# todo: there seems to be still some remaining import error with Conda env
|
||||
except ImportError:
|
||||
_MLFLOW_AVAILABLE = False
|
||||
mlflow, MlflowClient = None, None
|
||||
mlflow, MlflowClient, context = None, None, None
|
||||
|
||||
|
||||
# before v1.1.0
|
||||
if hasattr(context, 'resolve_tags'):
|
||||
from mlflow.tracking.context import resolve_tags
|
||||
# since v1.1.0
|
||||
elif hasattr(context, 'registry'):
|
||||
from mlflow.tracking.context.registry import resolve_tags
|
||||
else:
|
||||
|
||||
def resolve_tags(tags=None):
|
||||
return tags
|
||||
|
||||
|
||||
class MLFlowLogger(LightningLoggerBase):
|
||||
|
@ -140,7 +153,7 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
)
|
||||
|
||||
if self._run_id is None:
|
||||
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
|
||||
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
|
||||
self._run_id = run.info.run_id
|
||||
return self._mlflow_client
|
||||
|
||||
|
|
Loading…
Reference in New Issue