try to relax dependency on 3rd integrations (#17829)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ca30fd7752
commit
99f38f67d5
|
@ -12,10 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""General utilities."""
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from lightning_utilities.core.imports import package_available, RequirementCache
|
||||
from lightning_utilities.core.rank_zero import rank_zero_warn
|
||||
|
||||
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
|
||||
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
|
||||
|
@ -26,5 +28,18 @@ _OMEGACONF_AVAILABLE = package_available("omegaconf")
|
|||
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
|
||||
_LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")
|
||||
_LIGHTNING_BAGUA_AVAILABLE = RequirementCache("lightning-bagua")
|
||||
_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana")
|
||||
_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore")
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _try_import_module(module_name: str) -> bool:
|
||||
try:
|
||||
__import__(module_name)
|
||||
return True
|
||||
# added also AttributeError fro case of impoerts like pl.LightningModule
|
||||
except (ImportError, AttributeError) as err:
|
||||
rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}")
|
||||
return False
|
||||
|
||||
|
||||
_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") and _try_import_module("lightning_graphcore")
|
||||
_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") and _try_import_module("lightning_habana")
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import operator
|
||||
import os
|
||||
from contextlib import contextmanager, ExitStack, redirect_stdout
|
||||
from io import StringIO
|
||||
|
@ -25,6 +26,7 @@ from unittest.mock import ANY
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from lightning_utilities import compare_version
|
||||
from lightning_utilities.test.warning import no_warning_call
|
||||
from tensorboard.backend.event_processing import event_accumulator
|
||||
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
|
||||
|
@ -254,6 +256,7 @@ def test_lightning_cli_args(cleandir):
|
|||
assert loaded_config["trainer"] == cli_config["trainer"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
|
||||
def test_lightning_env_parse(cleandir):
|
||||
out = StringIO()
|
||||
with mock.patch("sys.argv", ["", "fit", "--help"]), redirect_stdout(out), pytest.raises(SystemExit):
|
||||
|
@ -402,6 +405,7 @@ def any_model_any_data_cli():
|
|||
LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
|
||||
def test_lightning_cli_help():
|
||||
cli_args = ["any.py", "fit", "--help"]
|
||||
out = StringIO()
|
||||
|
@ -755,6 +759,7 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base
|
|||
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
|
||||
|
||||
|
||||
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
|
||||
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
|
||||
class TestModel(BoringModel):
|
||||
def __init__(
|
||||
|
@ -865,6 +870,7 @@ def test_lightning_cli_subcommands():
|
|||
assert e in parameters
|
||||
|
||||
|
||||
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
|
||||
def test_lightning_cli_custom_subcommand():
|
||||
class TestTrainer(Trainer):
|
||||
def foo(self, model: LightningModule, x: int, y: float = 1.0):
|
||||
|
|
|
@ -565,10 +565,12 @@ def mock_ipu_available(monkeypatch, value=True):
|
|||
|
||||
def mock_hpu_available(monkeypatch, value=True):
|
||||
try:
|
||||
import lightning_habana
|
||||
except ModuleNotFoundError:
|
||||
__import__("lightning_habana")
|
||||
except (ImportError, AttributeError):
|
||||
return
|
||||
|
||||
import lightning_habana
|
||||
|
||||
monkeypatch.setattr(lightning_habana.accelerator.HPUAccelerator, "is_available", lambda: value)
|
||||
monkeypatch.setattr(lightning_habana.accelerator, "_HPU_AVAILABLE", value)
|
||||
monkeypatch.setattr(lightning_habana.strategies.parallel, "_HPU_AVAILABLE", value)
|
||||
|
|
Loading…
Reference in New Issue