Remove TPU Availability check from parse devices (#12326)
* Remove TPU Availability check from parse devices
* Update tests
parent 4fe0076a1a
commit 041da417db
@@ -17,7 +17,6 @@ import torch
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
-from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _DEVICE
@@ -122,7 +121,7 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
     Raises:
         MisconfigurationException:
-            If TPU cores aren't 1 or 8 cores, or no TPU devices are found
+            If TPU cores aren't 1, 8 or [<1-8>]
     """
     _check_data_type(tpu_cores)
@@ -132,9 +131,6 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
     if not _tpu_cores_valid(tpu_cores):
         raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

-    if tpu_cores is not None and not _TPU_AVAILABLE:
-        raise MisconfigurationException("No TPU devices were found.")
-
     return tpu_cores
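For context, the shape of `parse_tpu_cores` after this change can be sketched as follows. This is a minimal, illustrative reconstruction built only from the hunks above: `_check_data_type` and `_tpu_cores_valid` are real helpers referenced in the diff, but their bodies here are simplified stand-ins, not the library's exact implementation.

```python
from typing import List, Optional, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _check_data_type(tpu_cores) -> None:
    # Simplified stand-in: accept None, int, str, or list, matching the signature above.
    if tpu_cores is not None and not isinstance(tpu_cores, (int, str, list)):
        raise MisconfigurationException("`tpu_cores` must be an int, a string, a list or None")


def _tpu_cores_valid(tpu_cores) -> bool:
    # Valid requests: no TPUs, 1 core, all 8 cores, or a single core index in [1, 8].
    if tpu_cores is None or tpu_cores in (1, 8):
        return True
    if isinstance(tpu_cores, list) and len(tpu_cores) == 1:
        return 1 <= tpu_cores[0] <= 8
    return False


def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
    _check_data_type(tpu_cores)
    if isinstance(tpu_cores, str):
        tpu_cores = int(tpu_cores)  # simplified; the real parser accepts a few more string forms
    if not _tpu_cores_valid(tpu_cores):
        raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
    # What is gone after this commit: no `_TPU_AVAILABLE` check and no
    # "No TPU devices were found." error at parse time.
    return tpu_cores
```

With the availability check removed from parsing, the test changes below only need to stub `TPUAccelerator.is_available`.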
@@ -446,8 +446,7 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
-@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_ipython_compatible_strategy_tpu(mock_devices, mock_tpu_acc_avail, monkeypatch):
+def test_ipython_compatible_strategy_tpu(mock_tpu_acc_avail, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
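The test updates follow directly from the code change: since device parsing no longer depends on real hardware, only `TPUAccelerator.is_available` has to be stubbed to drive TPU code paths on a machine without TPUs. A minimal illustration of that pattern (the test name and the final assertion are mine; the `Trainer` call mirrors the hunks in this commit):

```python
from unittest import mock

from pytorch_lightning import Trainer


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
def test_tpu_trainer_without_hardware(mock_tpu_acc_avail):
    # Pretending the accelerator is available is now enough; there is no need to
    # also patch `parse_devices`, because parsing no longer checks availability.
    trainer = Trainer(accelerator="tpu", devices=8)
    assert trainer.num_devices == 8
```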
@@ -894,8 +893,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
-@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_unsupported_tpu_choice(mock_devices, mock_tpu_acc_avail):
+def test_unsupported_tpu_choice(mock_tpu_acc_avail):
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)
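A small detail worth noting when reading the hunk above: `match` in `pytest.raises` is interpreted as a regular expression, which is why the closing parenthesis in the expected message is escaped as `precision=64\)`. A self-contained illustration of the same idiom (the function and message here are made up for the example):

```python
import pytest


def configure_precision(precision: int) -> None:
    # Hypothetical stand-in that rejects double precision, mirroring the message shape above.
    if precision == 64:
        raise NotImplementedError("`Trainer(accelerator='tpu', precision=64)` is not implemented")


def test_unsupported_precision_message() -> None:
    # `match` is a regex searched against the exception text, so literal
    # parentheses in the expected message must be escaped.
    with pytest.raises(NotImplementedError, match=r"precision=64\)` is not implemented"):
        configure_precision(64)
```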
@@ -1139,9 +1139,8 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):
 def test_trainer_tpu_cores(monkeypatch):
-    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
-    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "parse_devices", lambda: 8)
-    trainer = Trainer(accelerator="TPU", devices=8)
+    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda _: True)
+    trainer = Trainer(accelerator="tpu", devices=8)
     with pytest.deprecated_call(
         match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
         "Please use `Trainer.num_devices` instead."
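One subtlety in the last hunk: the stub for `is_available` changes from `lambda: True` to `lambda _: True`. When a plain function is assigned onto a class with `monkeypatch.setattr` and the attribute is then looked up through an instance (or any other path that binds it), the stub receives the implicit first argument, so it must accept one positional parameter. A small illustration independent of Lightning, using made-up classes:

```python
class ClassLevelCheck:
    @staticmethod
    def is_available() -> bool:
        return False


class InstanceLevelCheck:
    def is_available(self) -> bool:
        return False


def test_patching_availability(monkeypatch) -> None:
    # Called directly on the class: the replacement function gets no arguments.
    monkeypatch.setattr(ClassLevelCheck, "is_available", lambda: True)
    assert ClassLevelCheck.is_available()

    # Called through an instance: the replacement is bound and receives `self`,
    # hence the one-argument form `lambda _: True`.
    monkeypatch.setattr(InstanceLevelCheck, "is_available", lambda _: True)
    assert InstanceLevelCheck().is_available()
```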