lightning/tests/tests_pytorch/loggers/conftest.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

96 lines
2.8 KiB
Python
Raw Normal View History

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from types import ModuleType
from unittest.mock import Mock
import pytest
@pytest.fixture()
def mlflow_mock(monkeypatch):
    """Register a stub ``mlflow`` package in ``sys.modules`` for one test.

    The stub provides ``mlflow.set_tracking_uri`` and two submodules:

    * ``mlflow.tracking`` with ``MlflowClient`` and ``artifact_utils``
    * ``mlflow.entities`` with ``Metric``, ``Param`` and ``time``

    Every exposed name is an independent ``Mock``. Entries are installed through
    ``monkeypatch.setitem`` so pytest restores ``sys.modules`` at teardown.

    Returns:
        The stub top-level ``mlflow`` module.
    """
    root = ModuleType("mlflow")
    root.set_tracking_uri = Mock()
    monkeypatch.setitem(sys.modules, "mlflow", root)

    tracking = ModuleType("tracking")
    for attr_name in ("MlflowClient", "artifact_utils"):
        setattr(tracking, attr_name, Mock())
    monkeypatch.setitem(sys.modules, "mlflow.tracking", tracking)

    entities = ModuleType("entities")
    for attr_name in ("Metric", "Param", "time"):
        setattr(entities, attr_name, Mock())
    monkeypatch.setitem(sys.modules, "mlflow.entities", entities)

    # Attach the submodules to the parent too, so attribute access
    # (``mlflow.tracking``) yields the same objects as the sys.modules entries.
    root.tracking, root.entities = tracking, entities
    return root
@pytest.fixture()
def wandb_mock(monkeypatch):
    """Register a stub ``wandb`` package in ``sys.modules`` for one test.

    ``wandb.init`` returns a run mock specced on a local ``RunType`` class. That
    same class is published as ``wandb.sdk.lib.RunDisabled`` and as
    ``wandb.wandb_run.Run``, so ``isinstance`` checks against either name
    accept the returned run. The remaining top-level names (``run``,
    ``require``, ``Api``, ``Artifact``, ``Image``, ``Table``) are plain mocks.
    Entries are installed through ``monkeypatch.setitem`` so pytest restores
    ``sys.modules`` at teardown.

    Returns:
        The stub top-level ``wandb`` module.
    """

    class RunType:  # to make isinstance checks pass
        pass

    fake_run = Mock(
        spec=RunType,
        log=Mock(),
        config=Mock(),
        watch=Mock(),
        log_artifact=Mock(),
        use_artifact=Mock(),
        id="run_id",
    )

    wandb_module = ModuleType("wandb")
    wandb_module.init = Mock(return_value=fake_run)
    for attr_name in ("run", "require", "Api", "Artifact", "Image", "Table"):
        setattr(wandb_module, attr_name, Mock())
    monkeypatch.setitem(sys.modules, "wandb", wandb_module)

    sdk = ModuleType("sdk")
    monkeypatch.setitem(sys.modules, "wandb.sdk", sdk)

    sdk_lib = ModuleType("lib")
    sdk_lib.RunDisabled = RunType
    monkeypatch.setitem(sys.modules, "wandb.sdk.lib", sdk_lib)

    run_module = ModuleType("wandb_run")
    run_module.Run = RunType
    monkeypatch.setitem(sys.modules, "wandb.wandb_run", run_module)

    # Mirror the dotted sys.modules layout as attributes on the parent modules.
    sdk.lib = sdk_lib
    wandb_module.sdk = sdk
    wandb_module.wandb_run = run_module
    return wandb_module
@pytest.fixture()
def comet_mock(monkeypatch):
    """Register a stub ``comet_ml`` package in ``sys.modules`` for one test.

    The top-level stub exposes ``Experiment``, ``ExistingExperiment``,
    ``OfflineExperiment``, ``API`` and ``config`` as mocks. It also exposes a
    ``comet_ml.api`` submodule with its own ``API`` mock. Note that
    ``comet_ml.API`` and ``comet_ml.api.API`` are two *separate* ``Mock``
    instances. Entries are installed through ``monkeypatch.setitem`` so pytest
    restores ``sys.modules`` at teardown.

    Returns:
        The stub top-level ``comet_ml`` module.
    """
    api_module = ModuleType("api")
    api_module.API = Mock()

    comet_module = ModuleType("comet_ml")
    for attr_name in ("Experiment", "ExistingExperiment", "OfflineExperiment", "API", "config"):
        setattr(comet_module, attr_name, Mock())
    comet_module.api = api_module

    for dotted_name, module in (("comet_ml", comet_module), ("comet_ml.api", api_module)):
        monkeypatch.setitem(sys.modules, dotted_name, module)
    return comet_module