fix: `nb` is set to the total number of devices when `nb` is -1 (#4209)

* fix: `nb` is set to the total number of devices when `nb` is -1

Refs: #4207

* feat: add test code
    1. test the combination of the `auto_select_gpus` and `gpus` options via the `Trainer`
    2. test the `pick_multiple_gpus` function directly

Refs: #4207

* docs: update the `Select GPU devices` section

Refs: #4207

* refactor: reflect the result of review

Refs: #4207

* refactor: reflect the result of review

Refs: #4207

* Update CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Roger Shieh <55400948+s-rog@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Martin Hwang 2020-10-29 18:50:37 +09:00 committed by GitHub
parent ce261e4afe
commit b459fd26ac
6 changed files with 99 additions and 0 deletions


@@ -66,12 +66,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
- Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))
- Fixed synchronization of best model path in `ddp_accelerator` ([#4323](https://github.com/PyTorchLightning/pytorch-lightning/pull/4323))
- Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))

## [1.0.3] - 2020-10-20

### Added
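For context, a minimal sketch of what the `gpus=-1` entry above fixes (the comments describe the new semantics; actual device selection depends on the machine):

    from pytorch_lightning import Trainer

    # Before this fix, gpus=-1 ("use all devices") combined with
    # auto_select_gpus=True raised an error, because -1 was fed directly
    # into the GPU-picking loop. It is now resolved to
    # torch.cuda.device_count() first, so all unoccupied GPUs are picked.
    trainer = Trainer(gpus=-1, auto_select_gpus=True)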


@@ -206,6 +206,8 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
`auto_select_gpus=True` will automatically help you find `k` gpus that are not
occupied by other processes. This is especially useful when GPUs are configured
to be in "exclusive mode", such that only one process at a time can access them.
For more details see the :ref:`Trainer guide <trainer>`.
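For illustration, a minimal sketch (mirroring the example added to the Trainer docs in this same commit)::

    # find two GPUs that no other process is using and train on them
    trainer = Trainer(gpus=2, auto_select_gpus=True)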
Remove CUDA flags
^^^^^^^^^^^^^^^^^


@@ -381,6 +381,12 @@ Example::
    # enable auto selection (will find two available gpus on system)
    trainer = Trainer(gpus=2, auto_select_gpus=True)
    # specifies all GPUs regardless of their availability
    Trainer(gpus=-1, auto_select_gpus=False)

    # specifies all available GPUs (if only one GPU is not occupied, uses that one)
    Trainer(gpus=-1, auto_select_gpus=True)
auto_lr_find
^^^^^^^^^^^^


@@ -13,8 +13,18 @@
# limitations under the License.
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
def pick_multiple_gpus(nb):
    if nb == 0:
        raise MisconfigurationException(
            r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
            Please select a valid number of GPU resources when using auto_select_gpus."
        )

    nb = torch.cuda.device_count() if nb == -1 else nb

    picked = []
    for _ in range(nb):
        picked.append(pick_single_gpu(exclude_gpus=picked))
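A quick sketch of the resulting behavior (hypothetical 2-GPU machine; the expected index lists mirror the parametrized test added below, and actual picks depend on which GPUs are free):

    # on an otherwise idle 2-GPU machine:
    pick_multiple_gpus(-1)  # -> [0, 1]; -1 now resolves to torch.cuda.device_count()
    pick_multiple_gpus(1)   # -> [0]
    pick_multiple_gpus(0)   # raises MisconfigurationException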

tests/tuner/__init__.py (new, empty file)


@@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires a multi-GPU machine"
)
@pytest.mark.parametrize(
    ["auto_select_gpus", "gpus", "expected_error"],
    [
        (True, 0, MisconfigurationException),
        (True, -1, None),
        (False, 0, None),
        (False, -1, None),
    ],
)
def test_trainer_with_gpus_options_combination_at_available_gpus_env(
    auto_select_gpus, gpus, expected_error
):
    if expected_error:
        with pytest.raises(
            expected_error,
            match=re.escape(
                r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
                Please select a valid number of GPU resources when using auto_select_gpus."
            ),
        ):
            Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)
    else:
        trainer = Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires a multi-GPU machine"
)
@pytest.mark.parametrize(
    ["nb", "expected_gpu_idxs", "expected_error"],
    [
        (0, [], MisconfigurationException),
        (-1, list(range(torch.cuda.device_count())), None),
        (1, [0], None),
    ],
)
def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error):
    if expected_error:
        with pytest.raises(
            expected_error,
            match=re.escape(
                r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
                Please select a valid number of GPU resources when using auto_select_gpus."
            ),
        ):
            pick_multiple_gpus(nb)
    else:
        assert expected_gpu_idxs == pick_multiple_gpus(nb)