diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py
index c7cc9a35c1..1e8f82a057 100644
--- a/src/lightning_lite/accelerators/cuda.py
+++ b/src/lightning_lite/accelerators/cuda.py
@@ -82,7 +82,7 @@ def _get_all_available_cuda_gpus() -> List[int]:
 def _patch_cuda_is_available() -> Generator:
     """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
     possible."""
-    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_13:
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
         # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
         # otherwise, patching is_available could lead to attribute errors or infinite recursion
         orig_check = torch.cuda.is_available
diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py
index 767de75411..19c3770084 100644
--- a/src/lightning_lite/utilities/distributed.py
+++ b/src/lightning_lite/utilities/distributed.py
@@ -136,10 +136,10 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
 
     # WA for HPU. HPU doesn't support Long types, forcefully set it to float
     if module_available("habana_frameworks.torch.utils.library_loader"):
-        from habana_frameworks.torch.utils.library_loader import is_habana_available
+        from habana_frameworks.torch.utils.library_loader import is_habana_avaialble
 
         if (
-            is_habana_available()
+            is_habana_avaialble()
             and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
             and result.type() in ("torch.LongTensor", "torch.hpu.LongTensor")
         ):
diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py
index 9072fe79ce..dd578c5907 100644
--- a/tests/tests_pytorch/serve/test_servable_module_validator.py
+++ b/tests/tests_pytorch/serve/test_servable_module_validator.py
@@ -1,8 +1,9 @@
 from typing import Dict
 
+import pytest
 import torch
 
-from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.serve.servable_module_validator import ServableModule, ServableModuleValidator
 
@@ -29,14 +30,13 @@ class ServableBoringModel(BoringModel, ServableModule):
 
 
 def test_servable_module_validator():
-    seed_everything(42)
     model = ServableBoringModel()
     callback = ServableModuleValidator()
     callback.on_train_start(Trainer(), model)
 
 
+@pytest.mark.flaky(reruns=3)
 def test_servable_module_validator_with_trainer(tmpdir):
-    seed_everything(42)
     callback = ServableModuleValidator()
     trainer = Trainer(
         default_root_dir=tmpdir,
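
For context on the first hunk: dropping `and not _TORCH_GREATER_EQUAL_1_13` means the NVML-based patch is applied on newer PyTorch versions as well, not only on torch < 1.13. The sketch below is a hedged, self-contained illustration of that patching pattern, not the actual Lightning implementation; it assumes `torch.cuda._device_count_nvml()` is available (present in recent PyTorch releases, returning -1 when NVML discovery fails), and the helper name `_nvml_based_is_available` is invented for the example.

```python
from contextlib import contextmanager
from typing import Generator

import torch


def _nvml_based_is_available() -> bool:
    # NVML reports the device count without initializing the CUDA context;
    # a count of -1 means NVML discovery itself failed.
    return torch.cuda._device_count_nvml() > 0


@contextmanager
def _patch_cuda_is_available() -> Generator[None, None, None]:
    # Only patch when torch was built with CUDA support and the NVML count succeeds;
    # otherwise patching could raise attribute errors or recurse indefinitely.
    if hasattr(torch._C, "_cuda_getDeviceCount") and torch.cuda._device_count_nvml() >= 0:
        orig_check = torch.cuda.is_available
        torch.cuda.is_available = _nvml_based_is_available
        try:
            yield
        finally:
            torch.cuda.is_available = orig_check
    else:
        yield
```

The try/finally restores the original `torch.cuda.is_available` even if the wrapped block raises, which is the point of doing the patch inside a context manager rather than patching globally.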