diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index cbd0e2309e..ad0779d88b 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -74,3 +74,8 @@ class Accelerator(ABC):
     @abstractmethod
     def is_available() -> bool:
         """Detect if the hardware is available."""
+
+    @staticmethod
+    @abstractmethod
+    def name() -> str:
+        """Name of the Accelerator."""
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index d586478619..a027e7db6e 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -62,3 +62,8 @@ class CPUAccelerator(Accelerator):
     def is_available() -> bool:
         """CPU is always available for execution."""
         return True
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "cpu"
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index f9181e8802..529d067025 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -93,6 +93,11 @@ class GPUAccelerator(Accelerator):
     def is_available() -> bool:
         return torch.cuda.device_count() > 0
 
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "gpu"
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py
index 2ac1c79461..1e8b2bc27f 100644
--- a/pytorch_lightning/accelerators/ipu.py
+++ b/pytorch_lightning/accelerators/ipu.py
@@ -46,3 +46,8 @@ class IPUAccelerator(Accelerator):
     @staticmethod
     def is_available() -> bool:
         return _IPU_AVAILABLE
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "ipu"
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index cd84cccd8b..dfdc950e70 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -64,3 +64,8 @@ class TPUAccelerator(Accelerator):
     @staticmethod
     def is_available() -> bool:
         return _TPU_AVAILABLE
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "tpu"
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index b31f598d82..7038dffb98 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -465,16 +465,15 @@ class AcceleratorConnector:
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
-        # TODO add device availability check
+        ACCELERATORS = {
+            "cpu": CPUAccelerator,
+            "gpu": GPUAccelerator,
+            "tpu": TPUAccelerator,
+            "ipu": IPUAccelerator,
+        }
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
         else:
-            ACCELERATORS = {
-                "cpu": CPUAccelerator,
-                "gpu": GPUAccelerator,
-                "tpu": TPUAccelerator,
-                "ipu": IPUAccelerator,
-            }
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
             if self._accelerator_flag not in ACCELERATORS:
@@ -485,6 +484,15 @@
             accelerator_class = ACCELERATORS[self._accelerator_flag]
             self.accelerator = accelerator_class()  # type: ignore[abstract]
 
+        if not self.accelerator.is_available():
+            available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
+            raise MisconfigurationException(
+                f"{self.accelerator.__class__.__qualname__} can not run on your system"
+                f" since {self.accelerator.name().upper()}s are not available."
+                " The following accelerator(s) is available and can be passed into"
+                f" `accelerator` argument of `Trainer`: {available_accelerator}."
+            )
+
         self._set_devices_flag_if_auto_passed()
 
         self._gpus = self._devices_flag if not self._gpus else self._gpus
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 4dd2ea9847..0a41feaec2 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -358,6 +358,10 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock):
         def is_available() -> bool:
             return True
 
+        @staticmethod
+        def name() -> str:
+            return "custom_acc_name"
+
     class Prec(PrecisionPlugin):
         pass
@@ -429,8 +433,9 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
 
 
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_ipython_compatible_strategy_tpu(_, monkeypatch):
+def test_ipython_compatible_strategy_tpu(mock_devices, mock_tpu_acc_avail, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
@@ -480,9 +485,10 @@ def test_accelerator_cpu(_):
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(gpus=1)
 
-    # TODO enable this test when add device availability check
-    # with pytest.raises(MisconfigurationException, match="You requested gpu, but gpu is not available"):
-    #     trainer = Trainer(accelerator="gpu")
+    with pytest.raises(
+        MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
+    ):
+        trainer = Trainer(accelerator="gpu")
 
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(accelerator="cpu", gpus=1)
@@ -899,8 +905,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_unsupported_tpu_choice(mock_devices):
+def test_unsupported_tpu_choice(mock_devices, mock_tpu_acc_avail):
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)
@@ -915,7 +922,8 @@
         Trainer(accelerator="tpu", precision=16, amp_backend="apex", strategy="single_device")
 
 
-def test_unsupported_ipu_choice(monkeypatch):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
     import pytorch_lightning.strategies.ipu as ipu
     import pytorch_lightning.utilities.imports as imports
diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py
index 473546696e..ef8780f698 100644
--- a/tests/accelerators/test_common.py
+++ b/tests/accelerators/test_common.py
@@ -44,6 +44,10 @@ def test_pluggable_accelerator():
         def is_available():
             return True
 
+        @staticmethod
+        def name():
+            return "custom_acc_name"
+
     trainer = Trainer(accelerator=TestAccelerator(), devices=2, strategy="ddp")
     assert isinstance(trainer.accelerator, TestAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index a7829690eb..11d876a3ed 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from typing import Optional
+from unittest import mock
 
 import pytest
 import torch
@@ -97,7 +98,8 @@ class IPUClassificationModel(ClassificationModel):
 
 
 @pytest.mark.skipif(_IPU_AVAILABLE, reason="test requires non-IPU machine")
-def test_fail_if_no_ipus(tmpdir):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir):
     with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
         Trainer(default_root_dir=tmpdir, ipus=1)
diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py
index a8957d46fb..f482c1ff97 100644
--- a/tests/plugins/test_cluster_integration.py
+++ b/tests/plugins/test_cluster_integration.py
@@ -58,7 +58,8 @@ def environment_combinations():
     "strategy_cls",
     [DDPStrategy, DDPShardedStrategy, DDP2Strategy, pytest.param(DeepSpeedStrategy, marks=RunIf(deepspeed=True))],
 )
-def test_ranks_available_manual_strategy_selection(strategy_cls):
+@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
+def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     for cluster, variables, expected in environment_combinations():
diff --git a/tests/trainer/properties/test_estimated_stepping_batches.py b/tests/trainer/properties/test_estimated_stepping_batches.py
index 320dd55692..203aaff65e 100644
--- a/tests/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/trainer/properties/test_estimated_stepping_batches.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from unittest import mock
 
 import pytest
 import torch
@@ -148,7 +149,8 @@ def test_num_stepping_batches_with_tpu(devices, estimated_steps):
     assert trainer.estimated_stepping_batches == estimated_steps
 
 
-def test_num_stepping_batches_with_ipu(monkeypatch):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_num_stepping_batches_with_ipu(mock_ipu_acc_avail, monkeypatch):
     """Test stepping batches with IPU training which acts like DP."""
     import pytorch_lightning.strategies.ipu as ipu