parent
b96dd21d69
commit
7ecb0d2528
|
@ -171,7 +171,7 @@ def test_root_gpu_property_0_passing(mocked_device_count_0, gpus, expected_root_
|
|||
])
|
||||
def test_root_gpu_property_0_raising(mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
|
||||
with pytest.raises(MisconfigurationException):
|
||||
Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu
|
||||
Trainer(gpus=gpus, distributed_backend=distributed_backend)
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
|
|
|
@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
|
@ -91,7 +92,6 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
|||
parser.parse_args(cli_args)
|
||||
|
||||
|
||||
# todo: add also testing for "gpus"
|
||||
@pytest.mark.parametrize(['cli_args', 'expected'], [
|
||||
pytest.param('--auto_lr_find --auto_scale_batch_size power',
|
||||
{'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
|
||||
|
@ -113,10 +113,25 @@ def test_argparse_args_parsing(cli_args, expected):
|
|||
assert Trainer.from_argparse_args(args)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 7),
|
||||
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
|
||||
)
|
||||
@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
|
||||
pytest.param('--gpus 1', [0]),
|
||||
pytest.param('--gpus 0,', [0]),
|
||||
])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
|
||||
"""Test multi type argument with bool."""
|
||||
cli_args = cli_args.split(' ') if cli_args else []
|
||||
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
|
||||
parser = ArgumentParser(add_help=False)
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
args = Trainer.parse_argparser(parser)
|
||||
|
||||
trainer = Trainer.from_argparse_args(args)
|
||||
assert trainer.data_parallel_device_ids == expected_gpu
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 7),
|
||||
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec")
|
||||
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
|
||||
pytest.param({}, {}),
|
||||
pytest.param({'logger': False}, {}),
|
||||
|
|
Loading…
Reference in New Issue