96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
# Copyright The Lightning AI team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import sys
|
|
from types import ModuleType
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
|
|
@pytest.fixture()
def mlflow_mock(monkeypatch):
    """Install a stand-in ``mlflow`` package into ``sys.modules`` for one test.

    The fake top-level module exposes ``set_tracking_uri``. Two submodules are
    also registered:

    * ``mlflow.tracking``, with ``MlflowClient`` and ``artifact_utils``;
    * ``mlflow.entities``, with ``Metric``, ``Param`` and ``time``.

    Every attribute is a fresh ``unittest.mock.Mock``. The ``sys.modules``
    entries are set through ``monkeypatch.setitem``, so pytest restores the
    previous entries after the test.

    Returns:
        The fake top-level ``mlflow`` module object.
    """

    def make_module(module_name, **members):
        # Create an empty module object and attach each keyword as an attribute.
        fake = ModuleType(module_name)
        for member_name, member in members.items():
            setattr(fake, member_name, member)
        return fake

    root = make_module("mlflow", set_tracking_uri=Mock())
    monkeypatch.setitem(sys.modules, "mlflow", root)

    tracking = make_module("tracking", MlflowClient=Mock(), artifact_utils=Mock())
    monkeypatch.setitem(sys.modules, "mlflow.tracking", tracking)

    entities = make_module("entities", Metric=Mock(), Param=Mock(), time=Mock())
    monkeypatch.setitem(sys.modules, "mlflow.entities", entities)

    # Registering in sys.modules covers `import mlflow.tracking`. The parent
    # also needs the attribute links for `mlflow.tracking.X` access to resolve.
    root.tracking = tracking
    root.entities = entities
    return root
|
|
|
|
|
|
@pytest.fixture()
def wandb_mock(monkeypatch):
    """Install a stand-in ``wandb`` package into ``sys.modules`` for one test.

    ``wandb.init`` is a Mock that returns a fake run object. That run is a
    ``Mock`` spec'd on a local placeholder class, with ``id == "run_id"`` and
    Mock-valued ``log``, ``config``, ``watch``, ``log_artifact`` and
    ``use_artifact``.

    The same placeholder class is published as ``wandb.wandb_run.Run`` and as
    ``wandb.sdk.lib.RunDisabled``. Because a spec'd Mock passes ``isinstance``
    checks against its spec, the fake run satisfies both.

    The remaining top-level attributes (``run``, ``require``, ``Api``,
    ``Artifact``, ``Image``, ``Table``) are plain Mocks. The ``wandb``,
    ``wandb.sdk``, ``wandb.sdk.lib`` and ``wandb.wandb_run`` entries are
    registered via ``monkeypatch.setitem``.

    Returns:
        The fake top-level ``wandb`` module object.
    """

    class RunType:  # placeholder type; the fake run is spec'd on it for isinstance checks
        pass

    fake_run = Mock(
        spec=RunType,
        log=Mock(),
        config=Mock(),
        watch=Mock(),
        log_artifact=Mock(),
        use_artifact=Mock(),
        id="run_id",
    )

    wandb = ModuleType("wandb")
    wandb.init = Mock(return_value=fake_run)
    for attr_name in ("run", "require", "Api", "Artifact", "Image", "Table"):
        setattr(wandb, attr_name, Mock())

    sdk = ModuleType("sdk")
    sdk_lib = ModuleType("lib")
    sdk_lib.RunDisabled = RunType
    run_module = ModuleType("wandb_run")
    run_module.Run = RunType

    # Wire the submodules onto their parents so dotted attribute access works.
    wandb.sdk = sdk
    sdk.lib = sdk_lib
    wandb.wandb_run = run_module

    # Same registration order as the original: parent first, then submodules.
    for module_key, module in (
        ("wandb", wandb),
        ("wandb.sdk", sdk),
        ("wandb.sdk.lib", sdk_lib),
        ("wandb.wandb_run", run_module),
    ):
        monkeypatch.setitem(sys.modules, module_key, module)

    return wandb
|
|
|
|
|
|
@pytest.fixture()
def comet_mock(monkeypatch):
    """Install a stand-in ``comet_ml`` package into ``sys.modules`` for one test.

    The fake top-level module exposes ``Experiment``, ``ExistingExperiment``,
    ``OfflineExperiment``, ``API`` and ``config`` as fresh Mocks. A
    ``comet_ml.api`` submodule is also registered, exposing its own ``API``
    Mock. Both entries are registered via ``monkeypatch.setitem``.

    Returns:
        The fake top-level ``comet_ml`` module object.
    """
    comet = ModuleType("comet_ml")
    monkeypatch.setitem(sys.modules, "comet_ml", comet)
    for attr_name in ("Experiment", "ExistingExperiment", "OfflineExperiment", "API", "config"):
        setattr(comet, attr_name, Mock())

    api_module = ModuleType("api")
    # Deliberately a different Mock object from `comet.API`; calls on one are
    # not recorded on the other.
    api_module.API = Mock()
    monkeypatch.setitem(sys.modules, "comet_ml.api", api_module)

    # Link the submodule so `comet_ml.api.API` resolves through attribute access.
    comet.api = api_module
    return comet
|