Add `auto_device_count` method to `Accelerators` (#10222)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Kaushik B 2021-10-30 02:01:32 +05:30 committed by GitHub
parent 848ad3f41d
commit cedaebfcbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 0 deletions

View File

@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
### Changed
- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
import torch
@ -705,3 +706,8 @@ class Accelerator:
"`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
)
return self.training_type_plugin.on_train_batch_start(batch, batch_idx)
@staticmethod
@abstractmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""

View File

@ -37,3 +37,8 @@ class CPUAccelerator(Accelerator):
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""CPU device stats aren't supported yet."""
return {}
@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 1

View File

@ -79,6 +79,11 @@ class GPUAccelerator(Accelerator):
super().teardown()
self._move_optimizer_state(torch.device("cpu"))
@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return torch.cuda.device_count()
def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

View File

@ -37,3 +37,10 @@ class IPUAccelerator(Accelerator):
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""IPU device stats aren't supported yet."""
return {}
@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
# TODO (@kaushikb11): 4 is the minimal unit they are shipped in.
# Update this when api is exposed by the Graphcore team.
return 4

View File

@ -73,3 +73,8 @@ class TPUAccelerator(Accelerator):
"avg. peak memory (MB)": peak_memory,
}
return device_stats
@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 8

View File

@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
import pytest
import torch
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.utilities.seed import seed_everything
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
@ -69,3 +72,11 @@ def test_model_parallel_setup_called(tmpdir):
trainer.fit(model)
assert model.configure_sharded_model_called
@mock.patch("torch.cuda.device_count", return_value=2)
def test_auto_device_count(device_count_mock):
assert CPUAccelerator.auto_device_count() == 1
assert GPUAccelerator.auto_device_count() == 2
assert TPUAccelerator.auto_device_count() == 8
assert IPUAccelerator.auto_device_count() == 4