From b459fd26ac773484e4c97c12e4bab221bb1609b0 Mon Sep 17 00:00:00 2001
From: Martin Hwang
Date: Thu, 29 Oct 2020 18:50:37 +0900
Subject: [PATCH] fix: `nb` is set to the total number of devices when `nb` is -1 (#4209)

* fix: `nb` is set to the total number of devices when `nb` is -1

  Refs: #4207

* feat: add test code

  1. test combinations of the `auto_select_gpus` and `gpus` options using Trainer
  2. test the `pick_multiple_gpus` function directly

  Refs: #4207

* docs: modify contents in `Select GPU devices`

  Refs: #4207

* refactor: reflect the result of review

  Refs: #4207

* refactor: reflect the result of review

  Refs: #4207

* Update CHANGELOG.md

Co-authored-by: chaton
Co-authored-by: Roger Shieh <55400948+s-rog@users.noreply.github.com>
Co-authored-by: Nicki Skafte
---
 CHANGELOG.md                               |  7 ++
 docs/source/multi_gpu.rst                  |  2 +
 pytorch_lightning/trainer/__init__.py      |  6 ++
 pytorch_lightning/tuner/auto_gpu_select.py | 10 +++
 tests/tuner/__init__.py                    |  0
 tests/tuner/test_auto_gpu_select.py        | 74 ++++++++++++++++++++++
 6 files changed, 99 insertions(+)
 create mode 100644 tests/tuner/__init__.py
 create mode 100644 tests/tuner/test_auto_gpu_select.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1de62b442f..803ece1a51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,12 +66,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+
+- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
+
+
 - Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))
 
+
 - Fixed synchronization of best model path in `ddp_accelerator` ([#4323](https://github.com/PyTorchLightning/pytorch-lightning/pull/4323))
 
+
 - Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))
 
+
 ## [1.0.3] - 2020-10-20
 
 ### Added
diff --git a/docs/source/multi_gpu.rst b/docs/source/multi_gpu.rst
index ea49601a39..8ea8646e13 100644
--- a/docs/source/multi_gpu.rst
+++ b/docs/source/multi_gpu.rst
@@ -206,6 +206,8 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
 `auto_select_gpus=True` will automatically help you find `k` gpus that are not occupied by other processes.
 This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them.
 
+For more details see the :ref:`Trainer guide `.
+
 Remove CUDA flags
 ^^^^^^^^^^^^^^^^^
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index a4bf2969f4..954befb00a 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -381,6 +381,12 @@ Example::
 
     # enable auto selection (will find two available gpus on system)
     trainer = Trainer(gpus=2, auto_select_gpus=True)
 
+    # specify all GPUs regardless of their availability
+    Trainer(gpus=-1, auto_select_gpus=False)
+
+    # specify all available GPUs (if only one GPU is not occupied, uses that one)
+    Trainer(gpus=-1, auto_select_gpus=True)
+
 auto_lr_find
 ^^^^^^^^^^^^
diff --git a/pytorch_lightning/tuner/auto_gpu_select.py b/pytorch_lightning/tuner/auto_gpu_select.py
index f1b13a6974..fd2ba4a1f3 100644
--- a/pytorch_lightning/tuner/auto_gpu_select.py
+++ b/pytorch_lightning/tuner/auto_gpu_select.py
@@ -13,8 +13,18 @@
 # limitations under the License.
 import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
 
 def pick_multiple_gpus(nb):
+    if nb == 0:
+        raise MisconfigurationException(
+            r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
+            Please select a valid number of GPU resources when using auto_select_gpus."
+        )
+
+    nb = torch.cuda.device_count() if nb == -1 else nb
+
     picked = []
     for _ in range(nb):
         picked.append(pick_single_gpu(exclude_gpus=picked))
diff --git a/tests/tuner/__init__.py b/tests/tuner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tuner/test_auto_gpu_select.py b/tests/tuner/test_auto_gpu_select.py
new file mode 100644
index 0000000000..36b33a707b
--- /dev/null
+++ b/tests/tuner/test_auto_gpu_select.py
@@ -0,0 +1,74 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="test requires a machine with more than 1 GPU"
+)
+@pytest.mark.parametrize(
+    ["auto_select_gpus", "gpus", "expected_error"],
+    [
+        (True, 0, MisconfigurationException),
+        (True, -1, None),
+        (False, 0, None),
+        (False, -1, None),
+    ],
+)
+def test_trainer_with_gpus_options_combination_at_available_gpus_env(
+    auto_select_gpus, gpus, expected_error
+):
+    if expected_error:
+        with pytest.raises(
+            expected_error,
+            match=re.escape(
+                r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
+            Please select a valid number of GPU resources when using auto_select_gpus."
+            ),
+        ):
+            trainer = Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)
+    else:
+        trainer = Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="test requires a machine with more than 1 GPU"
+)
+@pytest.mark.parametrize(
+    ["nb", "expected_gpu_idxs", "expected_error"],
+    [
+        (0, [], MisconfigurationException),
+        (-1, [i for i in range(torch.cuda.device_count())], None),
+        (1, [0], None),
+    ],
+)
+def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error):
+    if expected_error:
+        with pytest.raises(
+            expected_error,
+            match=re.escape(
+                r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
+            Please select a valid number of GPU resources when using auto_select_gpus."
+            ),
+        ):
+            pick_multiple_gpus(nb)
+    else:
+        assert expected_gpu_idxs == pick_multiple_gpus(nb)
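
A minimal usage sketch of the behaviour this patch introduces, assuming a CUDA-enabled machine; the printed index list is illustrative only::

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus

    if torch.cuda.is_available():
        # After this fix, nb == -1 expands to torch.cuda.device_count(), and each
        # device index is then checked for availability via pick_single_gpu.
        print(pick_multiple_gpus(-1))  # e.g. [0, 1] on an idle two-GPU machine

        # Equivalent Trainer configuration: auto-select every unoccupied GPU.
        trainer = Trainer(gpus=-1, auto_select_gpus=True)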