Various test fixes (#15068)

Carlos Mocholí 2022-10-11 09:47:16 +02:00 committed by GitHub
parent da25d1d30d
commit 6f16e46bdb
3 changed files with 6 additions and 6 deletions


@@ -82,7 +82,7 @@ def _get_all_available_cuda_gpus() -> List[int]:
 def _patch_cuda_is_available() -> Generator:
     """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
     possible."""
-    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_13:
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
         # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
         # otherwise, patching is_available could lead to attribute errors or infinite recursion
         orig_check = torch.cuda.is_available
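
The change removes the `not _TORCH_GREATER_EQUAL_1_13` condition, so the patch is applied whenever torch was compiled with CUDA and the NVML device count query succeeds. The part worth noting is the save/patch/restore shape of the context manager: the original function is stashed and put back in a `finally` block so the patch can never leak. A minimal, self-contained sketch of that pattern (`_nvml_is_available` is a hypothetical stand-in for Lightning's NVML helpers, not its actual API):

from contextlib import contextmanager
from typing import Generator

import torch


def _nvml_is_available() -> bool:
    # Illustrative stand-in: the real code answers this via an NVML query
    # (_device_count_nvml) rather than the CUDA runtime.
    return torch.cuda.device_count() > 0


@contextmanager
def _patched_is_available() -> Generator[None, None, None]:
    # Keep a reference to the original so the patch can always be undone.
    orig_check = torch.cuda.is_available
    torch.cuda.is_available = _nvml_is_available
    try:
        yield
    finally:
        # Restore in `finally` so an exception raised inside the block
        # never leaves torch.cuda.is_available permanently patched.
        torch.cuda.is_available = orig_check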


@@ -136,10 +136,10 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
     # WA for HPU. HPU doesn't support Long types, forcefully set it to float
     if module_available("habana_frameworks.torch.utils.library_loader"):
-        from habana_frameworks.torch.utils.library_loader import is_habana_available
+        from habana_frameworks.torch.utils.library_loader import is_habana_avaialble
 
         if (
-            is_habana_available()
+            is_habana_avaialble()
             and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
             and result.type() in ("torch.LongTensor", "torch.hpu.LongTensor")
         ):
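
The import switches to the misspelled `is_habana_avaialble` because that appears to be the name habana_frameworks actually exports, so the correctly spelled import would raise an ImportError on machines where the library is installed. The surrounding workaround casts Long tensors to float before the collective, since HCCL cannot reduce Long types. A rough sketch of that cast-reduce-cast-back shape in plain torch (hypothetical helper, not the habana-guarded code in this diff):

import torch
from torch import Tensor


def _reduce_long_as_float(result: Tensor) -> Tensor:
    # Backends that cannot reduce Long tensors get a float copy,
    # and the result is cast back to the original dtype afterwards.
    original_dtype = result.dtype
    if original_dtype == torch.long:
        result = result.float()
    # In real code, torch.distributed.all_reduce(result, op=...) runs here.
    return result.to(original_dtype)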


@@ -1,8 +1,9 @@
 from typing import Dict
 
+import pytest
 import torch
 
-from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.serve.servable_module_validator import ServableModule, ServableModuleValidator
@@ -29,14 +30,13 @@ class ServableBoringModel(BoringModel, ServableModule):
 def test_servable_module_validator():
-    seed_everything(42)
     model = ServableBoringModel()
     callback = ServableModuleValidator()
     callback.on_train_start(Trainer(), model)
 
 
+@pytest.mark.flaky(reruns=3)
 def test_servable_module_validator_with_trainer(tmpdir):
-    seed_everything(42)
     callback = ServableModuleValidator()
     trainer = Trainer(
         default_root_dir=tmpdir,
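
The new `@pytest.mark.flaky(reruns=3)` marker is provided by the pytest-rerunfailures plugin: a failing test is rerun up to three times and only reported as a failure if every attempt fails. A toy usage example (not from this repository), assuming the plugin is installed:

import random

import pytest


@pytest.mark.flaky(reruns=3)  # requires the pytest-rerunfailures plugin
def test_occasionally_fails():
    # Each failed attempt triggers a rerun, up to 3 times, before
    # pytest reports this test as failed.
    assert random.random() < 0.9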