[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn (#6719)
* update_logic
* update
* Update tests/utilities/test_xla_device_utils.py
* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* update test
* Update tests/utilities/test_xla_device_utils.py
* update
* Apply fix
* Docstring
* flake8
* update

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent 5b5a5cc80b
commit 3a4c4246ee
pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
-import numpy
 
+import numpy
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue
 
 import torch
+import torch.multiprocessing as mp
 
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
 
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25
 
 
 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
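Note on the lowered timeout: TPU_CHECK_TIMEOUT appears to be the number of seconds the pl_multi_process decorator (whose body is unchanged and therefore not shown in this diff) waits for the helper process started through inner_f before giving up. A minimal, self-contained sketch of that run-the-check-in-a-subprocess-with-a-timeout pattern, written independently of the actual Lightning code (TIMEOUT_SECONDS, _run_and_report and run_with_timeout are illustrative names):

import queue as q
import traceback
from multiprocessing import Process, Queue

TIMEOUT_SECONDS = 25  # stand-in for TPU_CHECK_TIMEOUT


def _run_and_report(result_queue, func, *args, **kwargs):
    # Child-process entry point: run the check and push its result back.
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception:
        traceback.print_exc()
        result_queue.put(None)


def run_with_timeout(func):
    # Run `func` in a separate process; if no answer arrives within
    # TIMEOUT_SECONDS, return False instead of hanging the caller.
    def wrapper(*args, **kwargs):
        result_queue = Queue()
        proc = Process(target=_run_and_report, args=(result_queue, func) + args, kwargs=kwargs)
        proc.start()
        proc.join(TIMEOUT_SECONDS)
        try:
            return result_queue.get_nowait()
        except q.Empty:
            return False

    return wrapper

Dropping the constant from 100 to 25 seconds only shortens how long a stuck check may block startup; it does not change the detection logic itself.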
@@ -55,23 +58,10 @@ def pl_multi_process(func):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""
 
-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False
 
     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
@@ -79,10 +69,18 @@ class XLADeviceUtils:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()
 
     @staticmethod
     def xla_available() -> bool:
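The replacement body above no longer asks xm.xla_device_hw for the device type in the current process; instead it has torch_xla bring up a device inside a single worker spawned by xmp.spawn and reports the outcome back through a queue. A standalone sketch of the same pattern, assuming torch_xla is installed (the wrapper name check_tpu is illustrative and not part of Lightning):

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _probe(_: int, result_queue) -> None:
    # Runs inside the worker started by xmp.spawn; the index argument is unused.
    try:
        device = xm.xla_device()
        result_queue.put(device.type == 'xla')
    except Exception:
        result_queue.put(False)


def check_tpu() -> bool:
    # Spawn exactly one XLA worker and read its verdict from the queue.
    smp = mp.get_context("spawn")
    result_queue = smp.SimpleQueue()
    xmp.spawn(_probe, args=(result_queue, ), nprocs=1)
    return result_queue.get()

Probing in a spawned worker keeps XLA runtime state, and any hang while acquiring a device, out of the parent process, which is what lets the timeout wrapper above cut it off safely.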
@@ -102,6 +100,14 @@ class XLADeviceUtils:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
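The rewritten tpu_device_exists also caches a positive detection in the PL_TPU_AVAILABLE environment variable: a value of "1" short-circuits the check, and because child processes inherit the environment, workers spawned later can skip the expensive probe. A minimal sketch of that caching idea with illustrative names (EXAMPLE_TPU_AVAILABLE and TPUChecker are not Lightning names):

import os

_ENV_FLAG = "EXAMPLE_TPU_AVAILABLE"  # illustrative; Lightning uses PL_TPU_AVAILABLE


class TPUChecker:
    _available = False

    @classmethod
    def exists(cls, probe) -> bool:
        # Trust a previously recorded positive result first.
        if os.getenv(_ENV_FLAG, '0') == "1":
            cls._available = True

        # Only run the (slow) probe when there is no cached answer yet.
        if not cls._available:
            cls._available = bool(probe())
            if cls._available:
                # Environment variables are inherited by child processes,
                # so spawned workers see the cached result for free.
                os.environ[_ENV_FLAG] = '1'

        return cls._available

Note that only the positive outcome is cached this way; a negative result is re-evaluated on the next call, mirroring the behaviour in the diff.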
tests/utilities/test_xla_device_utils.py
@@ -19,28 +19,35 @@ import pytest
 import pytorch_lightning.utilities.xla_device as xla_utils
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests.helpers.runif import RunIf
 from tests.helpers.utils import pl_multi_process_test
 
 
 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
 @RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()
 
 
-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
 
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result
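On the test changes: the old timeout test wrapped time.sleep directly and slept for 1.25x the timeout, so it could only assert on the failure path (time.sleep returns None, and the decorator presumably reports a falsy value when the subprocess does not answer in time). The new sleep_fn helper returns True and sleeps for half of the patched 3-second timeout, so the success path can be asserted instead. A self-contained illustration of why the helper is needed:

import time


def sleep_fn(sleep_time: float) -> bool:
    # Unlike time.sleep (which returns None), this helper returns True,
    # so a caller can tell "finished in time" apart from "timed out".
    time.sleep(sleep_time)
    return True


if __name__ == "__main__":
    assert time.sleep(0) is None  # plain sleep sends no signal back
    assert sleep_fn(0) is True    # the helper does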