test CLI parsing gpus (#2284)

* cli gpus

* test

* test
This commit is contained in:
Jirka Borovec 2020-06-20 05:41:42 +02:00 committed by GitHub
parent b96dd21d69
commit 7ecb0d2528
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 6 deletions

View File

@ -171,7 +171,7 @@ def test_root_gpu_property_0_passing(mocked_device_count_0, gpus, expected_root_
])
def test_root_gpu_property_0_raising(mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
with pytest.raises(MisconfigurationException):
Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu
Trainer(gpus=gpus, distributed_backend=distributed_backend)
@pytest.mark.gpus_param_tests

View File

@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace
from unittest import mock
import pytest
import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
@ -91,7 +92,6 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
parser.parse_args(cli_args)
# todo: add also testing for "gpus"
@pytest.mark.parametrize(['cli_args', 'expected'], [
pytest.param('--auto_lr_find --auto_scale_batch_size power',
{'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
@ -113,10 +113,25 @@ def test_argparse_args_parsing(cli_args, expected):
assert Trainer.from_argparse_args(args)
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
)
@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
pytest.param('--gpus 1', [0]),
pytest.param('--gpus 0,', [0]),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)
trainer = Trainer.from_argparse_args(args)
assert trainer.data_parallel_device_ids == expected_gpu
@pytest.mark.skipif(sys.version_info < (3, 7),
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec")
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
pytest.param({}, {}),
pytest.param({'logger': False}, {}),