Add support for empty `gpus` list to run on CPU (#10246)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
victorjoos 2021-11-01 19:37:38 +01:00 committed by GitHub
parent facaff94b8
commit cc0e9f96a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 2 deletions

View File

@ -125,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
- Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))

View File

@ -70,7 +70,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
_check_data_type(gpus)
# Handle the case when no gpus are requested
if gpus is None or isinstance(gpus, int) and gpus == 0 or str(gpus).strip() == "0":
if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"):
return None
# We know user requested GPUs therefore if some of the

View File

@ -182,6 +182,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu):
[
(None, None),
(0, None),
([], None),
(1, [0]),
(3, [0, 1, 2]),
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
@ -199,7 +200,7 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
assert device_parser.parse_gpu_ids(gpus) == expected_gpu_ids
@pytest.mark.parametrize("gpus", [0.1, -2, False, [], [-1], [None], ["0"], [0, 0]])
@pytest.mark.parametrize("gpus", [0.1, -2, False, [-1], [None], ["0"], [0, 0]])
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
with pytest.raises(MisconfigurationException):
device_parser.parse_gpu_ids(gpus)