try to relax dependency on 3rd integrations (#17829)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jirka Borovec 2023-06-15 11:29:07 +02:00 committed by GitHub
parent ca30fd7752
commit 99f38f67d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 4 deletions

View File

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""
import functools
import sys
import torch
from lightning_utilities.core.imports import package_available, RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_warn
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
@ -26,5 +28,18 @@ _OMEGACONF_AVAILABLE = package_available("omegaconf")
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
_LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")
_LIGHTNING_BAGUA_AVAILABLE = RequirementCache("lightning-bagua")
_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana")
_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore")
@functools.lru_cache(maxsize=128)
def _try_import_module(module_name: str) -> bool:
try:
__import__(module_name)
return True
# added also AttributeError fro case of impoerts like pl.LightningModule
except (ImportError, AttributeError) as err:
rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}")
return False
_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") and _try_import_module("lightning_graphcore")
_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") and _try_import_module("lightning_habana")

View File

@ -14,6 +14,7 @@
import glob
import inspect
import json
import operator
import os
from contextlib import contextmanager, ExitStack, redirect_stdout
from io import StringIO
@ -25,6 +26,7 @@ from unittest.mock import ANY
import pytest
import torch
import yaml
from lightning_utilities import compare_version
from lightning_utilities.test.warning import no_warning_call
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
@ -254,6 +256,7 @@ def test_lightning_cli_args(cleandir):
assert loaded_config["trainer"] == cli_config["trainer"]
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_env_parse(cleandir):
out = StringIO()
with mock.patch("sys.argv", ["", "fit", "--help"]), redirect_stdout(out), pytest.raises(SystemExit):
@ -402,6 +405,7 @@ def any_model_any_data_cli():
LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True)
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_cli_help():
cli_args = ["any.py", "fit", "--help"]
out = StringIO()
@ -755,6 +759,7 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
@ -865,6 +870,7 @@ def test_lightning_cli_subcommands():
assert e in parameters
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_cli_custom_subcommand():
class TestTrainer(Trainer):
def foo(self, model: LightningModule, x: int, y: float = 1.0):

View File

@ -565,10 +565,12 @@ def mock_ipu_available(monkeypatch, value=True):
def mock_hpu_available(monkeypatch, value=True):
try:
import lightning_habana
except ModuleNotFoundError:
__import__("lightning_habana")
except (ImportError, AttributeError):
return
import lightning_habana
monkeypatch.setattr(lightning_habana.accelerator.HPUAccelerator, "is_available", lambda: value)
monkeypatch.setattr(lightning_habana.accelerator, "_HPU_AVAILABLE", value)
monkeypatch.setattr(lightning_habana.strategies.parallel, "_HPU_AVAILABLE", value)