fix gpus default for Trainer.add_argparse_args (#6898)

This commit is contained in:
Adrian Wälchli 2021-04-09 11:20:43 +02:00 committed by GitHub
parent aaccbeea2b
commit 9c9e2a0325
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 15 additions and 26 deletions

View File

@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
## [1.2.7] - 2021-04-06
### Fixed

View File

@ -232,7 +232,6 @@ def add_argparse_args(
if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default
# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
@ -287,10 +286,6 @@ def _gpus_allowed_type(x) -> Union[int, str]:
return int(x)
def _gpus_arg_default(x) -> Union[int, str]:
return _gpus_allowed_type(x)
def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)

View File

@ -59,11 +59,6 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
If no GPUs are available but the value of gpus variable indicates request for GPUs
then a MisconfigurationException is raised.
"""
# nothing was passed into the GPUs argument
if callable(gpus):
return None
# Check that gpus param is None, Int, String or List
_check_data_type(gpus)
@ -97,10 +92,6 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
Returns:
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
"""
if callable(tpu_cores):
return None
_check_data_type(tpu_cores)
if isinstance(tpu_cores, str):

View File

@ -13,7 +13,6 @@
# limitations under the License.
import os
import pickle
import types
from argparse import ArgumentParser
from unittest import mock
@ -172,11 +171,10 @@ def test_wandb_sanitize_callable_params(tmpdir):
params.wrapper_something_wo_name = lambda: lambda: '1'
params.wrapper_something = wrapper_something
assert isinstance(params.gpus, types.FunctionType)
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == '_gpus_arg_default'
assert params["gpus"] == "None"
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"

View File

@ -175,12 +175,13 @@ def test_argparse_args_parsing(cli_args, expected):
assert Trainer.from_argparse_args(args)
@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
pytest.param('--gpus 1', [0]),
pytest.param('--gpus 0,', [0]),
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
pytest.param('', None, None),
pytest.param('--gpus 1', 1, [0]),
pytest.param('--gpus 0,', '0,', [0]),
])
@RunIf(min_gpus=1)
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
def test_argparse_args_parsing_gpus(cli_args, expected_parsed, expected_device_ids):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
@ -188,8 +189,9 @@ def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)
assert args.gpus == expected_parsed
trainer = Trainer.from_argparse_args(args)
assert trainer.data_parallel_device_ids == expected_gpu
assert trainer.data_parallel_device_ids == expected_device_ids
@RunIf(min_python="3.7.0")

View File

@ -7,7 +7,7 @@ import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import (
_gpus_arg_default,
_gpus_allowed_type,
_int_or_float_type,
add_argparse_args,
from_argparse_args,
@ -205,9 +205,9 @@ def test_add_argparse_args_no_argument_group():
assert args.my_parameter == 2
def test_gpus_arg_default():
assert _gpus_arg_default('1,2') == '1,2'
assert _gpus_arg_default('1') == 1
def test_gpus_allowed_type():
assert _gpus_allowed_type('1,2') == '1,2'
assert _gpus_allowed_type('1') == 1
def test_int_or_float_type():