add `accelerator.is_available()` check (#12104)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <kaushikbokka@gmail.com>
parent 5da065e287
commit 89d37569d8
@@ -74,3 +74,8 @@ class Accelerator(ABC):
     @abstractmethod
     def is_available() -> bool:
         """Detect if the hardware is available."""
+
+    @staticmethod
+    @abstractmethod
+    def name() -> str:
+        """Name of the Accelerator."""
@@ -62,3 +62,8 @@ class CPUAccelerator(Accelerator):
     def is_available() -> bool:
         """CPU is always available for execution."""
         return True
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "cpu"
@@ -93,6 +93,11 @@ class GPUAccelerator(Accelerator):
     def is_available() -> bool:
         return torch.cuda.device_count() > 0
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "gpu"
 
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
@@ -46,3 +46,8 @@ class IPUAccelerator(Accelerator):
     @staticmethod
     def is_available() -> bool:
         return _IPU_AVAILABLE
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "ipu"
@@ -64,3 +64,8 @@ class TPUAccelerator(Accelerator):
     @staticmethod
     def is_available() -> bool:
         return _TPU_AVAILABLE
+
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "tpu"
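Note (not part of the diff): together these hunks fix the static interface the connector relies on, with every accelerator answering `is_available()` and `name()`. A minimal sketch of a third-party accelerator honouring that contract follows; the class and the `my_chip` identifier are hypothetical, and the remaining abstract methods of `Accelerator` (setup, device parsing, stats, ...) are omitted, so the class is not instantiable as written.

from pytorch_lightning.accelerators import Accelerator


class MyChipAccelerator(Accelerator):
    """Hypothetical accelerator; only the two methods shown in the hunks above are sketched."""

    @staticmethod
    def is_available() -> bool:
        # A real implementation would probe the vendor runtime; hard-coded for the sketch.
        return False

    @staticmethod
    def name() -> str:
        # Short identifier surfaced in the connector's "not available" error message.
        return "my_chip"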
@@ -465,16 +465,15 @@ class AcceleratorConnector:
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
-        # TODO add device availability check
+        ACCELERATORS = {
+            "cpu": CPUAccelerator,
+            "gpu": GPUAccelerator,
+            "tpu": TPUAccelerator,
+            "ipu": IPUAccelerator,
+        }
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
         else:
-            ACCELERATORS = {
-                "cpu": CPUAccelerator,
-                "gpu": GPUAccelerator,
-                "tpu": TPUAccelerator,
-                "ipu": IPUAccelerator,
-            }
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
             if self._accelerator_flag not in ACCELERATORS:
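Note (not part of the diff): hoisting `ACCELERATORS` out of the `else` branch makes the registry visible to the availability check added in the next hunk, for both the string-flag path and the pass-an-instance path. A standalone sketch of that lookup-and-filter step, mirroring the connector logic (variable names are local to the sketch):

from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator

ACCELERATORS = {
    "cpu": CPUAccelerator,
    "gpu": GPUAccelerator,
    "tpu": TPUAccelerator,
    "ipu": IPUAccelerator,
}

# The connector builds this list for its error message: every registered
# accelerator whose hardware is actually usable on the current machine.
available = [acc_str for acc_str in ACCELERATORS if ACCELERATORS[acc_str].is_available()]
print(available)  # e.g. ["cpu"] on a machine without CUDA, TPU, or IPU support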
@@ -485,6 +484,15 @@ class AcceleratorConnector:
             accelerator_class = ACCELERATORS[self._accelerator_flag]
             self.accelerator = accelerator_class()  # type: ignore[abstract]
 
+        if not self.accelerator.is_available():
+            available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
+            raise MisconfigurationException(
+                f"{self.accelerator.__class__.__qualname__} can not run on your system"
+                f" since {self.accelerator.name().upper()}s are not available."
+                " The following accelerator(s) is available and can be passed into"
+                f" `accelerator` argument of `Trainer`: {available_accelerator}."
+            )
+
         self._set_devices_flag_if_auto_passed()
 
         self._gpus = self._devices_flag if not self._gpus else self._gpus
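Note (not part of the diff): with this check, an unavailable accelerator now fails fast at `Trainer` construction instead of surfacing later during setup. A rough usage sketch of the new behaviour, assuming a CPU-only machine (the matched message comes from the hunk above):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# GPUAccelerator.is_available() is False without CUDA devices, so the connector
# raises before any strategy, plugin, or device setup is attempted.
with pytest.raises(MisconfigurationException, match="GPUAccelerator can not run on your system"):
    Trainer(accelerator="gpu")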
@@ -358,6 +358,10 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock):
         def is_available() -> bool:
             return True
 
+        @staticmethod
+        def name() -> str:
+            return "custom_acc_name"
+
     class Prec(PrecisionPlugin):
         pass
 
@@ -429,8 +433,9 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
 
 
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_ipython_compatible_strategy_tpu(_, monkeypatch):
+def test_ipython_compatible_strategy_tpu(mock_devices, mock_tpu_acc_avail, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
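Note (not part of the diff): the test updates in this commit share one pattern: since the connector now consults `is_available()` before building anything, tests that exercise TPU/IPU/GPU code paths on hardware-less CI machines patch the static method to return True (stacked `mock.patch` decorators inject mocks bottom-up, which is why `mock_devices` precedes `mock_tpu_acc_avail` in the signatures). A minimal sketch of that pattern with an illustrative test name:

from unittest import mock

from pytorch_lightning import Trainer


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_tpu_path_without_hardware(mock_devices, mock_tpu_acc_avail):  # illustrative name
    # The patched availability lets Trainer get past the new check and select the TPU path.
    trainer = Trainer(accelerator="tpu")
    assert trainer.accelerator.name() == "tpu"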
@@ -480,9 +485,10 @@ def test_accelerator_cpu(_):
 
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(gpus=1)
-    # TODO enable this test when add device availability check
-    # with pytest.raises(MisconfigurationException, match="You requested gpu, but gpu is not available"):
-    #     trainer = Trainer(accelerator="gpu")
+    with pytest.raises(
+        MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
+    ):
+        trainer = Trainer(accelerator="gpu")
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(accelerator="cpu", gpus=1)
 
@@ -899,8 +905,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
-def test_unsupported_tpu_choice(mock_devices):
+def test_unsupported_tpu_choice(mock_devices, mock_tpu_acc_avail):
 
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)
@@ -915,7 +922,8 @@ def test_unsupported_tpu_choice(mock_devices):
         Trainer(accelerator="tpu", precision=16, amp_backend="apex", strategy="single_device")
 
 
-def test_unsupported_ipu_choice(monkeypatch):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
     import pytorch_lightning.strategies.ipu as ipu
     import pytorch_lightning.utilities.imports as imports
 
@@ -44,6 +44,10 @@ def test_pluggable_accelerator():
         def is_available():
             return True
 
+        @staticmethod
+        def name():
+            return "custom_acc_name"
+
     trainer = Trainer(accelerator=TestAccelerator(), devices=2, strategy="ddp")
     assert isinstance(trainer.accelerator, TestAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from typing import Optional
+from unittest import mock
 
 import pytest
 import torch
@@ -97,7 +98,8 @@ class IPUClassificationModel(ClassificationModel):
 
 
 @pytest.mark.skipif(_IPU_AVAILABLE, reason="test requires non-IPU machine")
-def test_fail_if_no_ipus(tmpdir):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir):
     with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
         Trainer(default_root_dir=tmpdir, ipus=1)
 
@@ -58,7 +58,8 @@ def environment_combinations():
     "strategy_cls",
     [DDPStrategy, DDPShardedStrategy, DDP2Strategy, pytest.param(DeepSpeedStrategy, marks=RunIf(deepspeed=True))],
 )
-def test_ranks_available_manual_strategy_selection(strategy_cls):
+@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
+def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     for cluster, variables, expected in environment_combinations():
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from unittest import mock
 
 import pytest
 import torch
@@ -148,7 +149,8 @@ def test_num_stepping_batches_with_tpu(devices, estimated_steps):
     assert trainer.estimated_stepping_batches == estimated_steps
 
 
-def test_num_stepping_batches_with_ipu(monkeypatch):
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_num_stepping_batches_with_ipu(mock_ipu_acc_avail, monkeypatch):
     """Test stepping batches with IPU training which acts like DP."""
     import pytorch_lightning.strategies.ipu as ipu
 