Various test fixes (#15068)
This commit is contained in:
parent
da25d1d30d
commit
6f16e46bdb
|
@ -82,7 +82,7 @@ def _get_all_available_cuda_gpus() -> List[int]:
|
|||
def _patch_cuda_is_available() -> Generator:
|
||||
"""Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
|
||||
possible."""
|
||||
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_13:
|
||||
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
|
||||
# we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
|
||||
# otherwise, patching is_available could lead to attribute errors or infinite recursion
|
||||
orig_check = torch.cuda.is_available
|
||||
|
|
|
@ -136,10 +136,10 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
|
|||
|
||||
# WA for HPU. HPU doesn't support Long types, forcefully set it to float
|
||||
if module_available("habana_frameworks.torch.utils.library_loader"):
|
||||
from habana_frameworks.torch.utils.library_loader import is_habana_available
|
||||
from habana_frameworks.torch.utils.library_loader import is_habana_avaialble
|
||||
|
||||
if (
|
||||
is_habana_available()
|
||||
is_habana_avaialble()
|
||||
and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
|
||||
and result.type() in ("torch.LongTensor", "torch.hpu.LongTensor")
|
||||
):
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning.serve.servable_module_validator import ServableModule, ServableModuleValidator
|
||||
|
||||
|
@ -29,14 +30,13 @@ class ServableBoringModel(BoringModel, ServableModule):
|
|||
|
||||
|
||||
def test_servable_module_validator():
|
||||
seed_everything(42)
|
||||
model = ServableBoringModel()
|
||||
callback = ServableModuleValidator()
|
||||
callback.on_train_start(Trainer(), model)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
def test_servable_module_validator_with_trainer(tmpdir):
|
||||
seed_everything(42)
|
||||
callback = ServableModuleValidator()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
|
Loading…
Reference in New Issue