Add check for unique device ids (#8666)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
This commit is contained in:
Isaac 2021-08-03 16:18:51 +08:00 committed by GitHub
parent e5d9e21dea
commit 8274183bf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 1 deletions

View File

@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
- Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666))
- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))

View File

@ -57,7 +57,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
Args:
gpus: An int -1 or string '-1' indicate that all available GPUs should be used.
A list of ints or a string containing list of comma separated integers
A list of unique ints or a string containing list of comma separated unique integers
indicates specific GPUs to use.
An int 0 means that no GPUs should be used.
Any int N > 0 indicates that GPUs [0..N) should be used.
@ -88,6 +88,10 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
# omit sanity check on torchelastic as by default shows one visible GPU per process
return gpus
# Check that gpus are unique. Duplicate gpus are not supported by the backend.
_check_unique(gpus)
return _sanitize_gpu_ids(gpus)
@ -188,6 +192,21 @@ def _get_all_available_gpus() -> List[int]:
return list(range(torch.cuda.device_count()))
def _check_unique(device_ids: List[int]) -> None:
"""
Checks that the device_ids are unique.
Args:
device_ids: list of ints corresponding to gpus indices
Raises:
MisconfigurationException:
If ``device_ids`` of GPUs aren't unique
"""
if len(device_ids) != len(set(device_ids)):
raise MisconfigurationException("Device ID's (GPU) must be unique.")
def _check_data_type(device_ids: Any) -> None:
"""
Checks that the device_ids argument is one of: None, Int, String or List.

View File

@ -217,6 +217,7 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
pytest.param([-1]),
pytest.param([None]),
pytest.param(["0"]),
pytest.param([0, 0]),
],
)
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):