diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml
index 26a5b983a5..55e2ced03a 100644
--- a/.github/checkgroup.yml
+++ b/.github/checkgroup.yml
@@ -86,7 +86,7 @@ subprojects:
       - ".github/workflows/tpu-tests.yml"
       - "tests/tests_pytorch/run_tpu_tests.sh"
     checks:
-      #- "test-on-tpus (pytorch, xrt)"
+      - "test-on-tpus (pytorch, xrt)"
       - "test-on-tpus (pytorch, pjrt)"
 
   - id: "fabric: Docs"
diff --git a/src/lightning/fabric/accelerators/tpu.py b/src/lightning/fabric/accelerators/tpu.py
index b40dbf6f1f..021012d89e 100644
--- a/src/lightning/fabric/accelerators/tpu.py
+++ b/src/lightning/fabric/accelerators/tpu.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from multiprocessing import Process, Queue
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Dict, List, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -69,11 +68,15 @@ class TPUAccelerator(Accelerator):
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
+        if not _XLA_AVAILABLE:
+            return 0
         import torch_xla.core.xla_env_vars as xenv
         from torch_xla.experimental import pjrt, tpu
         from torch_xla.utils.utils import getenv_as
 
         if pjrt.using_pjrt():
+            if _XLA_GREATER_EQUAL_2_1:
+                return tpu.num_available_devices()
             device_count_on_version = {2: 8, 3: 8, 4: 4}
             return device_count_on_version.get(tpu.version(), 8)
         else:
@@ -82,21 +85,7 @@ class TPUAccelerator(Accelerator):
     @staticmethod
     @functools.lru_cache(maxsize=1)
     def is_available() -> bool:
-        if not _XLA_AVAILABLE:
-            return False
-        queue: Queue = Queue()
-        proc = Process(target=_inner_f, args=(queue, _has_tpu_device))
-        proc.start()
-        proc.join(TPU_CHECK_TIMEOUT)
-        if proc.is_alive():
-            proc.terminate()
-            proc.join()
-            # if the timeout is triggered, fail to avoid silently running on a different accelerator
-            raise TimeoutError(
-                "Timed out waiting to check whether a TPU is available. You can increase the TPU_CHECK_TIMEOUT value."
-                f" Currently {TPU_CHECK_TIMEOUT}"
-            )
-        return queue.get_nowait()
+        return TPUAccelerator.auto_device_count() > 0
 
     @classmethod
     def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -107,36 +96,9 @@ class TPUAccelerator(Accelerator):
         )
 
 
-# define TPU availability timeout in seconds
-TPU_CHECK_TIMEOUT = 60
-
-
-def _inner_f(queue: Queue, func: Callable) -> None:
-    res = func()
-    queue.put(res)
-
-
-def _has_tpu_device() -> bool:
-    """Check if TPU devices are available.
-
-    Return:
-        A boolean value indicating if TPU devices are available
-    """
-    if not _XLA_AVAILABLE:
-        return False
-    import torch_xla.core.xla_model as xm
-    from torch_xla.experimental import pjrt
-
-    if pjrt.using_pjrt():
-        return bool(xm.get_xla_supported_devices("TPU"))
-    # For the TPU Pod training process, for example, if we have TPU v3-32 with 4 VMs, the world size would be 4 and as
-    # we would have to use `torch_xla.distributed.xla_dist` for multiple VMs and TPU_CONFIG won't be available, running
-    # `xm.get_xla_supported_devices("TPU")` won't be possible.
-    return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))
-
-
 # PJRT support requires this minimum version
 _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
 
 
 def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
diff --git a/tests/tests_fabric/accelerators/test_tpu.py b/tests/tests_fabric/accelerators/test_tpu.py
index bee26df79b..9eea23d900 100644
--- a/tests/tests_fabric/accelerators/test_tpu.py
+++ b/tests/tests_fabric/accelerators/test_tpu.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from lightning.fabric.accelerators.tpu import TPUAccelerator
+from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -25,9 +25,10 @@ def test_auto_device_count():
     assert TPUAccelerator.auto_device_count() > 1
 
 
-@RunIf(tpu=True)
-def test_availability():
-    assert TPUAccelerator.is_available()
+@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
+def test_tpu_device_absence():
+    """Check `is_available` returns False when torch_xla is not installed."""
+    assert not TPUAccelerator.is_available()
 
 
 @pytest.mark.parametrize("devices", (1, 8))
diff --git a/tests/tests_fabric/utilities/test_xla_device_utils.py b/tests/tests_fabric/utilities/test_xla_device_utils.py
deleted file mode 100644
index 3411d4564f..0000000000
--- a/tests/tests_fabric/utilities/test_xla_device_utils.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import sys
-import time
-
-import pytest
-
-from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
-from tests_fabric.helpers.runif import RunIf
-
-
-@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
-def test_tpu_device_absence():
-    """Check `is_available` returns True when TPU is available."""
-    assert not TPUAccelerator.is_available()
-
-
-@RunIf(tpu=True)
-def test_tpu_device_presence():
-    """Check `is_available` returns True when TPU is available."""
-    assert TPUAccelerator.is_available()
-
-
-def _t1_5():
-    time.sleep(1.5)
-    return True
-
-
-# this test runs very slowly on these platforms
-@RunIf(skip_windows=True)
-@pytest.mark.skipif(sys.platform == "darwin", reason="Times out")
-def test_result_returns_within_timeout_seconds(monkeypatch):
-    """Check that the TPU availability process launch returns within 3 seconds."""
-    from lightning.fabric.accelerators import tpu
-
-    timeout = 3
-    monkeypatch.setattr(tpu, "_XLA_AVAILABLE", True)
-    monkeypatch.setattr(tpu, "TPU_CHECK_TIMEOUT", timeout)
-    monkeypatch.setattr(tpu, "_has_tpu_device", _t1_5)
-    tpu.TPUAccelerator.is_available.cache_clear()
-
-    start = time.monotonic()
-
-    result = tpu.TPUAccelerator.is_available()
-
-    end = time.monotonic()
-    elapsed_time = end - start
-
-    # around 1.5 but definitely not 3 (timeout time)
-    assert 1 < elapsed_time < 2, elapsed_time
-    assert result
-
-    tpu.TPUAccelerator.is_available.cache_clear()
-
-
-def _t3():
-    time.sleep(3)
-    return True
-
-
-def test_timeout_triggered(monkeypatch):
-    """Check that the TPU availability process launch returns within 3 seconds."""
-    from lightning.fabric.accelerators import tpu
-
-    timeout = 1.5
-    monkeypatch.setattr(tpu, "_XLA_AVAILABLE", True)
-    monkeypatch.setattr(tpu, "TPU_CHECK_TIMEOUT", timeout)
-    monkeypatch.setattr(tpu, "_has_tpu_device", _t3)
-    tpu.TPUAccelerator.is_available.cache_clear()
-
-    start = time.monotonic()
-
-    with pytest.raises(TimeoutError, match="Timed out waiting"):
-        tpu.TPUAccelerator.is_available()
-
-    end = time.monotonic()
-    elapsed_time = end - start
-
-    # around 1.5 but definitely not 3 (fn time)
-    assert 1 < elapsed_time < 2, elapsed_time
-
-    tpu.TPUAccelerator.is_available.cache_clear()
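
For reference, a minimal usage sketch (not part of the patch) of how the simplified availability check behaves after this change; the reported count depends on the local torch_xla/TPU setup.

# Sketch only: assumes a lightning build that contains this patch.
from lightning.fabric.accelerators.tpu import TPUAccelerator

# auto_device_count() now short-circuits to 0 when torch_xla cannot be imported,
# so is_available() reduces to a simple comparison instead of a subprocess with a timeout.
count = TPUAccelerator.auto_device_count()
print(f"TPU devices detected: {count}")
print(f"TPU available: {TPUAccelerator.is_available()}")  # equivalent to count > 0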