Typing `tuner.auto_gpu_select` (#9292)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
jjenniferdai 2021-09-03 07:49:58 -07:00 committed by GitHub
parent d5ee8d8e3f
commit e97c28a02b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -66,6 +66,7 @@ module = [
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.fx_validator",
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
"pytorch_lightning.tuner.auto_gpu_select",
"pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",

View File

@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
def pick_multiple_gpus(nb):
def pick_multiple_gpus(nb: int) -> List[int]:
"""
Raises:
MisconfigurationException:
@ -30,14 +32,14 @@ def pick_multiple_gpus(nb):
nb = torch.cuda.device_count() if nb == -1 else nb
picked = []
picked: List[int] = []
for _ in range(nb):
picked.append(pick_single_gpu(exclude_gpus=picked))
return picked
def pick_single_gpu(exclude_gpus: list):
def pick_single_gpu(exclude_gpus: List[int]) -> int:
"""
Raises:
RuntimeError: