From 99f38f67d5ba54268c8a67777209b19463fe5620 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 15 Jun 2023 11:29:07 +0200 Subject: [PATCH] try to relax dependency on 3rd integrations (#17829) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/lightning/pytorch/utilities/imports.py | 19 +++++++++++++++++-- tests/tests_pytorch/test_cli.py | 6 ++++++ .../connectors/test_accelerator_connector.py | 6 ++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index ffe51116e8..fe25a37ab7 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities.""" +import functools import sys import torch from lightning_utilities.core.imports import package_available, RequirementCache +from lightning_utilities.core.rank_zero import rank_zero_warn _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") @@ -26,5 +28,18 @@ _OMEGACONF_AVAILABLE = package_available("omegaconf") _TORCHVISION_AVAILABLE = RequirementCache("torchvision") _LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai") _LIGHTNING_BAGUA_AVAILABLE = RequirementCache("lightning-bagua") -_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") -_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") + + +@functools.lru_cache(maxsize=128) +def _try_import_module(module_name: str) -> bool: + try: + __import__(module_name) + return True + # also added AttributeError for cases of imports like pl.LightningModule + except (ImportError, AttributeError) as err: + rank_zero_warn(f"Import of 
{module_name} package failed for some compatibility issues: \n{err}") + return False + + +_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") and _try_import_module("lightning_graphcore") +_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") and _try_import_module("lightning_habana") diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index f33b79caa0..5681189bc8 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -14,6 +14,7 @@ import glob import inspect import json +import operator import os from contextlib import contextmanager, ExitStack, redirect_stdout from io import StringIO @@ -25,6 +26,7 @@ from unittest.mock import ANY import pytest import torch import yaml +from lightning_utilities import compare_version from lightning_utilities.test.warning import no_warning_call from tensorboard.backend.event_processing import event_accumulator from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData @@ -254,6 +256,7 @@ def test_lightning_cli_args(cleandir): assert loaded_config["trainer"] == cli_config["trainer"] +@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports") def test_lightning_env_parse(cleandir): out = StringIO() with mock.patch("sys.argv", ["", "fit", "--help"]), redirect_stdout(out), pytest.raises(SystemExit): @@ -402,6 +405,7 @@ def any_model_any_data_cli(): LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True) +@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports") def test_lightning_cli_help(): cli_args = ["any.py", "fit", "--help"] out = StringIO() @@ -755,6 +759,7 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) 
+@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports") def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type(): class TestModel(BoringModel): def __init__( @@ -865,6 +870,7 @@ def test_lightning_cli_subcommands(): assert e in parameters +@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports") def test_lightning_cli_custom_subcommand(): class TestTrainer(Trainer): def foo(self, model: LightningModule, x: int, y: float = 1.0): diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 0131a1eae0..7fe31e132a 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -565,10 +565,12 @@ def mock_ipu_available(monkeypatch, value=True): def mock_hpu_available(monkeypatch, value=True): try: - import lightning_habana - except ModuleNotFoundError: + __import__("lightning_habana") + except (ImportError, AttributeError): return + import lightning_habana + monkeypatch.setattr(lightning_habana.accelerator.HPUAccelerator, "is_available", lambda: value) monkeypatch.setattr(lightning_habana.accelerator, "_HPU_AVAILABLE", value) monkeypatch.setattr(lightning_habana.strategies.parallel, "_HPU_AVAILABLE", value)