From cedaebfcbb2f0386d2e99f47bc9e92f717b3e42b Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Sat, 30 Oct 2021 02:01:32 +0530
Subject: [PATCH] Add `auto_device_count` method to `Accelerators` (#10222)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                  |  3 +++
 pytorch_lightning/accelerators/accelerator.py |  6 ++++++
 pytorch_lightning/accelerators/cpu.py         |  5 +++++
 pytorch_lightning/accelerators/gpu.py         |  5 +++++
 pytorch_lightning/accelerators/ipu.py         |  7 +++++++
 pytorch_lightning/accelerators/tpu.py         |  5 +++++
 tests/accelerators/test_common.py             | 11 +++++++++++
 7 files changed, 42 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 67516b8294..a0dbf92661 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
 
 
+- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
+
+
 ### Changed
 
 - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 058f3fc3fb..6d5a83417b 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from abc import abstractmethod
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
@@ -705,3 +706,8 @@ class Accelerator:
             "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
         )
         return self.training_type_plugin.on_train_batch_start(batch, batch_idx)
+
+    @staticmethod
+    @abstractmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index d16e8b6a8b..8b18676eff 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -37,3 +37,8 @@ class CPUAccelerator(Accelerator):
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """CPU device stats aren't supported yet."""
         return {}
+
+    @staticmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        return 1
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 44b29efe6f..b84e53c042 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -79,6 +79,11 @@ class GPUAccelerator(Accelerator):
         super().teardown()
         self._move_optimizer_state(torch.device("cpu"))
 
+    @staticmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        return torch.cuda.device_count()
+
 
 def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py
index ef38c9ed0b..0f6bdb8270 100644
--- a/pytorch_lightning/accelerators/ipu.py
+++ b/pytorch_lightning/accelerators/ipu.py
@@ -37,3 +37,10 @@ class IPUAccelerator(Accelerator):
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """IPU device stats aren't supported yet."""
         return {}
+
+    @staticmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        # TODO (@kaushikb11): 4 is the minimal unit they are shipped in.
+        # Update this when api is exposed by the Graphcore team.
+        return 4
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 7c7680c32f..6e824a25f6 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -73,3 +73,8 @@ class TPUAccelerator(Accelerator):
             "avg. peak memory (MB)": peak_memory,
         }
         return device_stats
+
+    @staticmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        return 8
diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py
index d38c5c5fad..18bb04bd0a 100644
--- a/tests/accelerators/test_common.py
+++ b/tests/accelerators/test_common.py
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 import torch
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
 from pytorch_lightning.utilities.seed import seed_everything
 from tests.accelerators.test_dp import CustomClassificationModelDP
 from tests.helpers.boring_model import BoringModel
@@ -69,3 +72,11 @@ def test_model_parallel_setup_called(tmpdir):
     trainer.fit(model)
 
     assert model.configure_sharded_model_called
+
+
+@mock.patch("torch.cuda.device_count", return_value=2)
+def test_auto_device_count(device_count_mock):
+    assert CPUAccelerator.auto_device_count() == 1
+    assert GPUAccelerator.auto_device_count() == 2
+    assert TPUAccelerator.auto_device_count() == 8
+    assert IPUAccelerator.auto_device_count() == 4
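---

Usage sketch: a minimal illustration of how the new `auto_device_count()` static method can be queried directly on each accelerator class, mirroring the assertions in `tests/accelerators/test_common.py` above. The mocked CUDA device count of 2 is only an illustrative assumption so the sketch runs on any machine.

    from unittest import mock

    from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator

    # CPU resolves "auto" to a single device.
    print(CPUAccelerator.auto_device_count())  # 1

    # GPU defers to torch.cuda.device_count(); mocked here so the example is machine-independent.
    with mock.patch("torch.cuda.device_count", return_value=2):
        print(GPUAccelerator.auto_device_count())  # 2

    # TPU and IPU return fixed defaults (8 cores, 4 IPUs) until a runtime query is available.
    print(TPUAccelerator.auto_device_count())  # 8
    print(IPUAccelerator.auto_device_count())  # 4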