Docs and Tests for "gpus" Trainer Argument (#593)
* add table for gpus argument * fix typo in error message * tests for supported values * tests for unsupported values * fix typo * add table for gpus argument * fix typo in error message * tests for supported values * tests for unsupported values * fix typo * fix typo list->str * fix travis warning "line too long"
This commit is contained in:
parent
cc65f39d97
commit
f7e1040236
|
@ -164,6 +164,38 @@ Make sure you're on a GPU machine. You can set as many GPUs as you want.
|
|||
# RECOMMENDED use DistributedDataParallel
|
||||
trainer = Trainer(gpus=8, distributed_backend='ddp')
|
||||
|
||||
Custom device selection
|
||||
-----------------------
|
||||
|
||||
The number of GPUs can also be selected with a list of indices or a string containing
|
||||
a comma separated list of GPU ids.
|
||||
The table below lists examples of possible input formats and how they are interpreted by Lightning.
|
||||
Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
|
||||
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| `gpus` | Type | Parsed | Meaning |
|
||||
+===============+===========+=====================+=================================+
|
||||
| None | NoneType | None | CPU |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| 0 | int | None | CPU |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| 3 | int | [0, 1, 2] | first 3 GPUs |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| -1 | int | [0, 1, 2, ...] | all available GPUs |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| [0] | list | [0] | GPU 0 |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| [1, 3] | list | [1, 3] | GPUs 1 and 3 |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| "0" | str | [0] | GPU 0 |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| "3" | str | [3] | GPU 3 |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| "1, 3" | str | [1, 3] | GPUs 1 and 3 |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
| "-1" | str | [0, 1, 2, ...] | all available GPUs |
|
||||
+---------------+-----------+---------------------+---------------------------------+
|
||||
|
||||
|
||||
Multi-node
|
||||
----------
|
||||
|
@ -531,7 +563,7 @@ def parse_gpu_ids(gpus):
|
|||
gpus = sanitize_gpu_ids(gpus)
|
||||
|
||||
if not gpus:
|
||||
raise MisconfigurationException("GPUs requested but non are available.")
|
||||
raise MisconfigurationException("GPUs requested but none are available.")
|
||||
return gpus
|
||||
|
||||
|
||||
|
|
|
@ -351,9 +351,14 @@ test_parse_gpu_ids_data = [
|
|||
pytest.param(None, None),
|
||||
pytest.param(0, None),
|
||||
pytest.param(1, [0]),
|
||||
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
|
||||
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
|
||||
pytest.param(3, [0, 1, 2]),
|
||||
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
|
||||
pytest.param([0], [0]),
|
||||
pytest.param([1, 3], [1, 3]),
|
||||
pytest.param('0', [0]),
|
||||
pytest.param('3', [3]),
|
||||
pytest.param('1, 3', [1, 3]),
|
||||
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
|
||||
]
|
||||
|
||||
|
||||
|
@ -363,6 +368,33 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
|
|||
assert parse_gpu_ids(gpus) == expected_gpu_ids
|
||||
|
||||
|
||||
test_parse_gpu_invalid_inputs_data = [
|
||||
pytest.param(0.1),
|
||||
pytest.param(-2),
|
||||
pytest.param(False),
|
||||
pytest.param([]),
|
||||
pytest.param([-1]),
|
||||
pytest.param([None]),
|
||||
pytest.param(['0']),
|
||||
pytest.param((0, 1)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data)
|
||||
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
|
||||
with pytest.raises(MisconfigurationException):
|
||||
parse_gpu_ids(gpus)
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize("gpus", [''])
|
||||
def test_parse_gpu_fail_on_empty_string(mocked_device_count, gpus):
|
||||
# This currently results in a ValueError instead of MisconfigurationException
|
||||
with pytest.raises(ValueError):
|
||||
parse_gpu_ids(gpus)
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1'])
|
||||
def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus):
|
||||
|
|
Loading…
Reference in New Issue