diff --git a/CHANGELOG.md b/CHANGELOG.md index 6599fd7e55..7c4255af8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -124,6 +124,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added RPC and Sharded plugins ([#5732](https://github.com/PyTorchLightning/pytorch-lightning/pull/5732)) * Added missing `LightningModule`-wrapper logic to new plugins and accelerator ([#5734](https://github.com/PyTorchLightning/pytorch-lightning/pull/5734)) + +- Increased TPU check timeout from 20s to 100s ([#5598](https://github.com/PyTorchLightning/pytorch-lightning/pull/5598)) + + ### Deprecated - Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) @@ -172,12 +176,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619)) -## [1.1.5] - 2021-01-19 - -### Fixed - - - ## [1.1.5] - 2021-01-19 ### Fixed diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 210047c466..fcf56e9c67 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -22,6 +22,8 @@ from pytorch_lightning.utilities.imports import _XLA_AVAILABLE if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm +#: define waiting time got checking TPU available in sec +TPU_CHECK_TIMEOUT = 100 def inner_f(queue, func, *args, **kwargs): # pragma: no cover @@ -40,7 +42,7 @@ def pl_multi_process(func): queue = Queue() proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) proc.start() - proc.join(20) + proc.join(TPU_CHECK_TIMEOUT) try: return queue.get_nowait() except q.Empty: diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index 20b7a04e7b..aa0af1697a 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from warnings import warn warn( diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index c0349fff46..dcafa85092 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from unittest.mock import patch import pytest import pytorch_lightning.utilities.xla_device as xla_utils -from pytorch_lightning.utilities import _TPU_AVAILABLE, _XLA_AVAILABLE +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities import _XLA_AVAILABLE from tests.base.develop_utils import pl_multi_process_test -# lets hope that in or env we have installed XLA only for TPU devices, otherwise, -# it is testing in the cycle "if I am true test that I am true :D" -@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") +@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): """Check tpu_device_exists returns None when torch_xla is not available""" assert xla_utils.XLADeviceUtils.tpu_device_exists() is None @@ -35,12 +35,12 @@ def test_tpu_device_presence(): assert xla_utils.XLADeviceUtils.tpu_device_exists() is True -def test_result_returns_within_20_seconds(): +@patch('pytorch_lightning.utilities.xla_device_utils.TPU_CHECK_TIMEOUT', 10) +def test_result_returns_within_timeout_seconds(): """Check that pl_multi_process returns within 10 seconds""" - start = time.time() - result = xla_utils.pl_multi_process(time.sleep)(25) + result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25) end = time.time() elapsed_time = int(end - start) - assert elapsed_time <= 20 + assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT assert result is False