From 7ecb0d25281ef1320959d036c8020231029106b2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 20 Jun 2020 05:41:42 +0200 Subject: [PATCH] test CLI parsing gpus (#2284) * cli gpus * test * test --- tests/models/test_gpu.py | 2 +- tests/trainer/test_trainer_cli.py | 25 ++++++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index cdfbd9b09c..b597af4fe8 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -171,7 +171,7 @@ def test_root_gpu_property_0_passing(mocked_device_count_0, gpus, expected_root_ ]) def test_root_gpu_property_0_raising(mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): with pytest.raises(MisconfigurationException): - Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu + Trainer(gpus=gpus, distributed_backend=distributed_backend) @pytest.mark.gpus_param_tests diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index c66d614903..2bd07560b1 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace from unittest import mock import pytest +import torch import tests.base.utils as tutils from pytorch_lightning import Trainer @@ -91,7 +92,6 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch): parser.parse_args(cli_args) -# todo: add also testing for "gpus" @pytest.mark.parametrize(['cli_args', 'expected'], [ pytest.param('--auto_lr_find --auto_scale_batch_size power', {'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}), @@ -113,10 +113,25 @@ def test_argparse_args_parsing(cli_args, expected): assert Trainer.from_argparse_args(args) -@pytest.mark.skipif( - sys.version_info < (3, 7), - reason="signature inspection while mocking is not working in Python < 3.7 despite autospec" -) +@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [ + pytest.param('--gpus 1', [0]), + pytest.param('--gpus 0,', [0]), +]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_argparse_args_parsing_gpus(cli_args, expected_gpu): + """Test multi type argument with bool.""" + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + + trainer = Trainer.from_argparse_args(args) + assert trainer.data_parallel_device_ids == expected_gpu + + +@pytest.mark.skipif(sys.version_info < (3, 7), + reason="signature inspection while mocking is not working in Python < 3.7 despite autospec") @pytest.mark.parametrize(['cli_args', 'extra_args'], [ pytest.param({}, {}), pytest.param({'logger': False}, {}),