Add `auto_device_count` method to `Accelerators` (#10222)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))


### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

import torch
@@ -705,3 +706,8 @@ class Accelerator:
            "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
        )
        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)

    @staticmethod
    @abstractmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
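For readers skimming the hunk above: `auto_device_count` is declared as an abstract `staticmethod`, so every concrete accelerator must override it with a backend-specific device count, and callers can query the class without instantiating it. A minimal standalone sketch of that pattern (the `BaseAccelerator` and `TwoDeviceAccelerator` names are illustrative placeholders, not part of this PR):

```python
from abc import ABC, abstractmethod


class BaseAccelerator(ABC):
    """Stand-in for the Accelerator base class, reduced to the new hook."""

    @staticmethod
    @abstractmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""


class TwoDeviceAccelerator(BaseAccelerator):
    """Hypothetical backend that always reports two devices."""

    @staticmethod
    def auto_device_count() -> int:
        return 2


if __name__ == "__main__":
    # Concrete subclasses can be queried directly on the class.
    print(TwoDeviceAccelerator.auto_device_count())  # 2
```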
@@ -37,3 +37,8 @@ class CPUAccelerator(Accelerator):
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """CPU device stats aren't supported yet."""
        return {}

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 1
@@ -79,6 +79,11 @@ class GPUAccelerator(Accelerator):
        super().teardown()
        self._move_optimizer_state(torch.device("cpu"))

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return torch.cuda.device_count()


def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
    """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
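The GPU implementation above simply defers to PyTorch, so the value it returns depends on the machine it runs on. A quick sanity check, not part of the diff:

```python
import torch

# Prints 0 on a CUDA-less machine; otherwise it reports the number of visible
# GPUs (respecting CUDA_VISIBLE_DEVICES), which is what "auto" will pick up.
print(torch.cuda.device_count())
```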
@@ -37,3 +37,10 @@ class IPUAccelerator(Accelerator):
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """IPU device stats aren't supported yet."""
        return {}

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        # TODO (@kaushikb11): 4 is the minimal unit they are shipped in.
        # Update this when api is exposed by the Graphcore team.
        return 4
@@ -73,3 +73,8 @@ class TPUAccelerator(Accelerator):
            "avg. peak memory (MB)": peak_memory,
        }
        return device_stats

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 8
@@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.utilities.seed import seed_everything
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
@@ -69,3 +72,11 @@ def test_model_parallel_setup_called(tmpdir):
    trainer.fit(model)

    assert model.configure_sharded_model_called


@mock.patch("torch.cuda.device_count", return_value=2)
def test_auto_device_count(device_count_mock):
    assert CPUAccelerator.auto_device_count() == 1
    assert GPUAccelerator.auto_device_count() == 2
    assert TPUAccelerator.auto_device_count() == 8
    assert IPUAccelerator.auto_device_count() == 4
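The `mock.patch` decorator is what makes the GPU assertion deterministic: `torch.cuda.device_count` is replaced with a stub returning `2` for the duration of the test, so it passes on machines with any number of GPUs, including none. A standalone illustration of the same technique; the `gpu_auto_device_count` helper is mine and only mirrors `GPUAccelerator.auto_device_count` from the diff:

```python
from unittest import mock

import torch


def gpu_auto_device_count() -> int:
    # Same body as GPUAccelerator.auto_device_count() above.
    return torch.cuda.device_count()


with mock.patch("torch.cuda.device_count", return_value=2):
    assert gpu_auto_device_count() == 2  # passes regardless of real hardware
print("ok")
```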