# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytest fixtures that install stub ``mlflow``, ``wandb`` and ``comet_ml`` packages in ``sys.modules``."""

import sys
from types import ModuleType
from unittest.mock import Mock

import pytest


def _fake_module(monkeypatch, sys_modules_key, module_name, **attributes):
    """Create a stub module, give it ``attributes`` and register it in ``sys.modules``.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture used to insert the ``sys.modules`` entry, so the entry is
            reverted when the monkeypatch is undone.
        sys_modules_key: Dotted import path under which the stub is registered (e.g. ``"mlflow.tracking"``).
        module_name: Name handed to ``ModuleType``. It becomes the stub's ``__name__`` and is deliberately the
            short, unqualified name (e.g. ``"tracking"``).
        **attributes: Attribute names and values to set on the stub.

    Returns:
        The registered stub module.

    """
    stub = ModuleType(module_name)
    for attr_name, attr_value in attributes.items():
        setattr(stub, attr_name, attr_value)
    monkeypatch.setitem(sys.modules, sys_modules_key, stub)
    return stub


@pytest.fixture()
def mlflow_mock(monkeypatch):
    """Replace ``mlflow``, ``mlflow.tracking`` and ``mlflow.entities`` with stub modules.

    ``mlflow`` exposes ``set_tracking_uri``, ``mlflow.tracking`` exposes ``MlflowClient`` and ``artifact_utils``,
    and ``mlflow.entities`` exposes ``Metric``, ``Param`` and ``time``. Each of these is a fresh ``Mock``.

    Returns:
        The stub ``mlflow`` module. Its ``tracking`` and ``entities`` attributes are the registered submodule stubs.

    """
    tracking = _fake_module(
        monkeypatch, "mlflow.tracking", "tracking", MlflowClient=Mock(), artifact_utils=Mock()
    )
    entities = _fake_module(
        monkeypatch, "mlflow.entities", "entities", Metric=Mock(), Param=Mock(), time=Mock()
    )
    return _fake_module(
        monkeypatch, "mlflow", "mlflow", set_tracking_uri=Mock(), tracking=tracking, entities=entities
    )


@pytest.fixture()
def wandb_mock(monkeypatch):
    """Replace ``wandb``, ``wandb.sdk``, ``wandb.sdk.lib`` and ``wandb.wandb_run`` with stub modules.

    ``wandb.init()`` returns a ``Mock`` spec'd on a local ``RunType`` class. That same class is exposed as both
    ``wandb.wandb_run.Run`` and ``wandb.sdk.lib.RunDisabled``.

    Returns:
        The stub ``wandb`` module.

    """

    class RunType:  # shared spec so the run mock passes isinstance checks against Run / RunDisabled
        pass

    fake_run = Mock(
        spec=RunType,
        log=Mock(),
        config=Mock(),
        watch=Mock(),
        log_artifact=Mock(),
        use_artifact=Mock(),
        id="run_id",
    )

    sdk_lib = _fake_module(monkeypatch, "wandb.sdk.lib", "lib", RunDisabled=RunType)
    sdk = _fake_module(monkeypatch, "wandb.sdk", "sdk", lib=sdk_lib)
    run_module = _fake_module(monkeypatch, "wandb.wandb_run", "wandb_run", Run=RunType)
    return _fake_module(
        monkeypatch,
        "wandb",
        "wandb",
        init=Mock(return_value=fake_run),
        run=Mock(),
        require=Mock(),
        Api=Mock(),
        Artifact=Mock(),
        Image=Mock(),
        Table=Mock(),
        sdk=sdk,
        wandb_run=run_module,
    )


@pytest.fixture()
def comet_mock(monkeypatch):
    """Replace ``comet_ml`` and ``comet_ml.api`` with stub modules.

    ``comet_ml.API`` and ``comet_ml.api.API`` are two distinct ``Mock`` objects.

    Returns:
        The stub ``comet_ml`` module.

    """
    api = _fake_module(monkeypatch, "comet_ml.api", "api", API=Mock())
    return _fake_module(
        monkeypatch,
        "comet_ml",
        "comet_ml",
        Experiment=Mock(),
        ExistingExperiment=Mock(),
        OfflineExperiment=Mock(),
        API=Mock(),
        config=Mock(),
        api=api,
    )