Docs and Tests for "gpus" Trainer Argument (#593)

* add table for gpus argument

* fix typo in error message

* tests for supported values

* tests for unsupported values

* fix typo

* add table for gpus argument

* fix typo in error message

* tests for supported values

* tests for unsupported values

* fix typo

* fix typo list->str

* fix travis warning "line too long"
This commit is contained in:
Adrian Wälchli 2019-12-07 14:48:45 +01:00 committed by William Falcon
parent cc65f39d97
commit f7e1040236
2 changed files with 67 additions and 3 deletions

View File

@ -164,6 +164,38 @@ Make sure you're on a GPU machine. You can set as many GPUs as you want.
# RECOMMENDED use DistributedDataParallel
trainer = Trainer(gpus=8, distributed_backend='ddp')
Custom device selection
-----------------------
The number of GPUs can also be selected with a list of indices or a string containing
a comma separated list of GPU ids.
The table below lists examples of possible input formats and how they are interpreted by Lightning.
Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
+---------------+-----------+---------------------+---------------------------------+
| `gpus` | Type | Parsed | Meaning |
+===============+===========+=====================+=================================+
| None | NoneType | None | CPU |
+---------------+-----------+---------------------+---------------------------------+
| 0 | int | None | CPU |
+---------------+-----------+---------------------+---------------------------------+
| 3 | int | [0, 1, 2] | first 3 GPUs |
+---------------+-----------+---------------------+---------------------------------+
| -1 | int | [0, 1, 2, ...] | all available GPUs |
+---------------+-----------+---------------------+---------------------------------+
| [0] | list | [0] | GPU 0 |
+---------------+-----------+---------------------+---------------------------------+
| [1, 3] | list | [1, 3] | GPUs 1 and 3 |
+---------------+-----------+---------------------+---------------------------------+
| "0" | str | [0] | GPU 0 |
+---------------+-----------+---------------------+---------------------------------+
| "3" | str | [3] | GPU 3 |
+---------------+-----------+---------------------+---------------------------------+
| "1, 3" | str | [1, 3] | GPUs 1 and 3 |
+---------------+-----------+---------------------+---------------------------------+
| "-1" | str | [0, 1, 2, ...] | all available GPUs |
+---------------+-----------+---------------------+---------------------------------+
Multi-node
----------
@ -531,7 +563,7 @@ def parse_gpu_ids(gpus):
gpus = sanitize_gpu_ids(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but non are available.")
raise MisconfigurationException("GPUs requested but none are available.")
return gpus

View File

@ -351,9 +351,14 @@ test_parse_gpu_ids_data = [
pytest.param(None, None),
pytest.param(0, None),
pytest.param(1, [0]),
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
pytest.param(3, [0, 1, 2]),
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
pytest.param([0], [0]),
pytest.param([1, 3], [1, 3]),
pytest.param('0', [0]),
pytest.param('3', [3]),
pytest.param('1, 3', [1, 3]),
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
]
@ -363,6 +368,33 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
assert parse_gpu_ids(gpus) == expected_gpu_ids
test_parse_gpu_invalid_inputs_data = [
pytest.param(0.1),
pytest.param(-2),
pytest.param(False),
pytest.param([]),
pytest.param([-1]),
pytest.param([None]),
pytest.param(['0']),
pytest.param((0, 1)),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data)
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
with pytest.raises(MisconfigurationException):
parse_gpu_ids(gpus)
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [''])
def test_parse_gpu_fail_on_empty_string(mocked_device_count, gpus):
# This currently results in a ValueError instead of MisconfigurationException
with pytest.raises(ValueError):
parse_gpu_ids(gpus)
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1'])
def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus):