[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn (#6719)

* update_logic

* update

* Update tests/utilities/test_xla_device_utils.py

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* update test

* Update tests/utilities/test_xla_device_utils.py

* update

* Apply fix

* Docstring

* flake8

* update

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
thomas chaton 2021-03-29 18:59:20 +01:00 committed by GitHub
parent 5b5a5cc80b
commit 3a4c4246ee
3 changed files with 47 additions and 34 deletions
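
For orientation, the public entry point touched by this change is XLADeviceUtils.tpu_device_exists(). A minimal usage sketch, based on the diff below (the PL_TPU_AVAILABLE caching is part of this change):

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

# Returns False when torch_xla is not installed; a positive result is cached
# in the PL_TPU_AVAILABLE environment variable so later calls are cheap.
if XLADeviceUtils.tpu_device_exists():
    print("TPU backend detected")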

pytorch_lightning/utilities/__init__.py

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import numpy
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import ( # noqa: F401
    AllGatherGrad,

pytorch_lightning/utilities/xla_device.py

@@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import queue as q
import traceback
from multiprocessing import Process, Queue
import torch
import torch.multiprocessing as mp
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
#: define waiting time for checking TPU availability in sec
TPU_CHECK_TIMEOUT = 100
TPU_CHECK_TIMEOUT = 25
def inner_f(queue, func, *args, **kwargs): # pragma: no cover
@@ -55,23 +58,10 @@ def pl_multi_process(func):
class XLADeviceUtils:
    """Used to detect the type of XLA device"""
    TPU_AVAILABLE = None
    @staticmethod
    def _fetch_xla_device_type(device: torch.device) -> str:
        """
        Returns XLA device type
        Args:
            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
        Return:
            Returns a str of the device hardware type. i.e TPU
        """
        if _XLA_AVAILABLE:
            return xm.xla_device_hw(device)
    _TPU_AVAILABLE = False
    @staticmethod
    @pl_multi_process
    def _is_device_tpu() -> bool:
        """
        Check if device is TPU
@@ -79,10 +69,18 @@ class XLADeviceUtils:
        Return:
            A boolean value indicating if the xla device is a TPU device or not
        """
        if _XLA_AVAILABLE:
            device = xm.xla_device()
            device_type = XLADeviceUtils._fetch_xla_device_type(device)
            return device_type == "TPU"
        def _fn(_: int, mp_queue):
            try:
                device = xm.xla_device()
                mp_queue.put(device.type == 'xla')
            except Exception:
                mp_queue.put(False)
        smp = mp.get_context("spawn")
        queue = smp.SimpleQueue()
        xmp.spawn(_fn, args=(queue, ), nprocs=1)
        return queue.get()
    @staticmethod
    def xla_available() -> bool:
@@ -102,6 +100,14 @@ class XLADeviceUtils:
        Return:
            A boolean value indicating if a TPU device exists on the system
        """
        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
        return XLADeviceUtils.TPU_AVAILABLE
        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
            XLADeviceUtils._TPU_AVAILABLE = True
        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
        if XLADeviceUtils._TPU_AVAILABLE:
            os.environ["PL_TPU_AVAILABLE"] = '1'
        return XLADeviceUtils._TPU_AVAILABLE
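
For clarity, here is a condensed, standalone sketch of the detection flow introduced above, assuming torch_xla is importable. It mirrors the names in the diff but omits the pl_multi_process timeout guard that wraps the real _is_device_tpu:

import os

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _fn(_: int, mp_queue) -> None:
    # Runs inside the spawned XLA process; acquiring an xla device only
    # succeeds when a TPU backend is actually reachable.
    try:
        device = xm.xla_device()
        mp_queue.put(device.type == 'xla')
    except Exception:
        mp_queue.put(False)


def tpu_device_exists() -> bool:
    # Fast path: a previous probe already recorded a positive result.
    if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
        return True
    queue = mp.get_context("spawn").SimpleQueue()
    xmp.spawn(_fn, args=(queue, ), nprocs=1)
    found = queue.get()
    if found:
        os.environ["PL_TPU_AVAILABLE"] = '1'
    return found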

tests/utilities/test_xla_device_utils.py

@@ -19,28 +19,35 @@ import pytest
import pytorch_lightning.utilities.xla_device as xla_utils
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
"""Check tpu_device_exists returns False when torch_xla is not available"""
assert not xla_utils.XLADeviceUtils.tpu_device_exists()
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_device_presence():
"""Check tpu_device_exists returns True when TPU is available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
assert xla_utils.XLADeviceUtils.tpu_device_exists()
@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
def sleep_fn(sleep_time: float) -> bool:
time.sleep(sleep_time)
return True
@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
def test_result_returns_within_timeout_seconds():
"""Check that pl_multi_process returns within 10 seconds"""
"""Check that pl_multi_process returns within 3 seconds"""
fn = xla_utils.pl_multi_process(sleep_fn)
start = time.time()
result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
end = time.time()
elapsed_time = int(end - start)
assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
assert result is False
assert result
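
The timeout test above exercises pl_multi_process, whose body is not shown in this diff. As a rough illustration of that pattern only (not Lightning's exact implementation; _inner_f here is a simplified stand-in for the inner_f helper visible at the top of the xla_device.py hunk): run the wrapped function in a child process and fall back to False once TPU_CHECK_TIMEOUT seconds elapse.

import functools
from multiprocessing import Process, Queue

TPU_CHECK_TIMEOUT = 25  # matches the value set in the diff above


def _inner_f(queue: Queue, func, *args, **kwargs) -> None:
    # Child-process entry point: run the wrapped function and report back.
    queue.put(func(*args, **kwargs))


def pl_multi_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue: Queue = Queue()
        proc = Process(target=_inner_f, args=(queue, func) + args, kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)
        if proc.is_alive():
            # The probe did not finish in time; treat the device as absent.
            proc.terminate()
            return False
        return queue.get()

    return wrapper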