Add support for empty `gpus` list to run on CPU (#10246)
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
parent
facaff94b8
commit
cc0e9f96a8
|
@ -125,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
|
||||
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
|
||||
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
|
||||
- Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
|
||||
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
|
||||
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
|
|||
_check_data_type(gpus)
|
||||
|
||||
# Handle the case when no gpus are requested
|
||||
if gpus is None or isinstance(gpus, int) and gpus == 0 or str(gpus).strip() == "0":
|
||||
if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"):
|
||||
return None
|
||||
|
||||
# We know user requested GPUs therefore if some of the
|
||||
|
|
|
@ -182,6 +182,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu):
|
|||
[
|
||||
(None, None),
|
||||
(0, None),
|
||||
([], None),
|
||||
(1, [0]),
|
||||
(3, [0, 1, 2]),
|
||||
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
|
||||
|
@ -199,7 +200,7 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
|
|||
assert device_parser.parse_gpu_ids(gpus) == expected_gpu_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gpus", [0.1, -2, False, [], [-1], [None], ["0"], [0, 0]])
|
||||
@pytest.mark.parametrize("gpus", [0.1, -2, False, [-1], [None], ["0"], [0, 0]])
|
||||
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
|
||||
with pytest.raises(MisconfigurationException):
|
||||
device_parser.parse_gpu_ids(gpus)
|
||||
|
|
Loading…
Reference in New Issue