fix gpus default for Trainer.add_argparse_args (#6898)
parent aaccbeea2b, commit 9c9e2a0325

@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))

+- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))

 ## [1.2.7] - 2021-04-06

 ### Fixed

@@ -232,7 +232,6 @@ def add_argparse_args(

         if arg == 'gpus' or arg == 'tpu_cores':
             use_type = _gpus_allowed_type
-            arg_default = _gpus_arg_default

         # hack for types in (int, float)
         if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):

@@ -287,10 +286,6 @@ def _gpus_allowed_type(x) -> Union[int, str]:
     return int(x)


-def _gpus_arg_default(x) -> Union[int, str]:
-    return _gpus_allowed_type(x)
-
-
 def _int_or_float_type(x) -> Union[int, float]:
     if '.' in str(x):
         return float(x)
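
With `arg_default = _gpus_arg_default` removed, `--gpus` keeps argparse's implicit `None` default while values are still converted by `_gpus_allowed_type`. A minimal standalone sketch of that behavior (an illustration assuming only what the hunks and tests in this commit show, not the library's exact code):

from argparse import ArgumentParser
from typing import Union


def _gpus_allowed_type(x) -> Union[int, str]:
    # comma-separated device lists stay strings, plain counts become ints
    if ',' in x:
        return x
    return int(x)


parser = ArgumentParser()
parser.add_argument('--gpus', type=_gpus_allowed_type, default=None)

print(parser.parse_args([]).gpus)                 # None -- the fixed default
print(parser.parse_args(['--gpus', '1']).gpus)    # 1
print(parser.parse_args(['--gpus', '0,']).gpus)   # '0,'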

@@ -59,11 +59,6 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
     If no GPUs are available but the value of gpus variable indicates request for GPUs
     then a MisconfigurationException is raised.
     """
-
-    # nothing was passed into the GPUs argument
-    if callable(gpus):
-        return None
-
     # Check that gpus param is None, Int, String or List
     _check_data_type(gpus)

@@ -97,10 +92,6 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]:
     Returns:
         a list of tpu_cores to be used or ``None`` if no TPU cores were requested
     """
-
-    if callable(tpu_cores):
-        return None
-
     _check_data_type(tpu_cores)

     if isinstance(tpu_cores, str):
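
The `callable(...)` guards deleted above were only needed because the old `--gpus` default was the `_gpus_arg_default` function itself: argparse stores a callable default as-is instead of calling it, so the raw function object could flow into the device parsers. A small standalone illustration of that argparse behavior (hypothetical names, not Lightning code):

from argparse import ArgumentParser


def _old_callable_default(x):  # stand-in for the removed _gpus_arg_default
    return x


parser = ArgumentParser()
parser.add_argument('--gpus', default=_old_callable_default)

args = parser.parse_args([])  # --gpus not given on the command line
print(callable(args.gpus))    # True: the function object itself is the value,
                              # which is why parse_gpu_ids() had to bail out early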

@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import pickle
-import types
 from argparse import ArgumentParser
 from unittest import mock

@@ -172,11 +171,10 @@ def test_wandb_sanitize_callable_params(tmpdir):
     params.wrapper_something_wo_name = lambda: lambda: '1'
     params.wrapper_something = wrapper_something

-    assert isinstance(params.gpus, types.FunctionType)
     params = WandbLogger._convert_params(params)
     params = WandbLogger._flatten_dict(params)
     params = WandbLogger._sanitize_callable_params(params)
-    assert params["gpus"] == '_gpus_arg_default'
+    assert params["gpus"] == "None"
     assert params["something"] == "something"
     assert params["wrapper_something"] == "wrapper_something"
     assert params["wrapper_something_wo_name"] == "<lambda>"

@@ -175,12 +175,13 @@ def test_argparse_args_parsing(cli_args, expected):
     assert Trainer.from_argparse_args(args)


-@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
-    pytest.param('--gpus 1', [0]),
-    pytest.param('--gpus 0,', [0]),
+@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
+    pytest.param('', None, None),
+    pytest.param('--gpus 1', 1, [0]),
+    pytest.param('--gpus 0,', '0,', [0]),
 ])
 @RunIf(min_gpus=1)
-def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
+def test_argparse_args_parsing_gpus(cli_args, expected_parsed, expected_device_ids):
     """Test multi type argument with bool."""
     cli_args = cli_args.split(' ') if cli_args else []
     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):

@@ -188,8 +189,9 @@ def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
         parser = Trainer.add_argparse_args(parent_parser=parser)
         args = Trainer.parse_argparser(parser)

+    assert args.gpus == expected_parsed
     trainer = Trainer.from_argparse_args(args)
-    assert trainer.data_parallel_device_ids == expected_gpu
+    assert trainer.data_parallel_device_ids == expected_device_ids


 @RunIf(min_python="3.7.0")
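
The end-to-end effect the updated test pins down can be sketched with the public Trainer helpers; this mirrors the parametrized expectations above (calling `parse_args` directly instead of mocking `sys.argv`, as the test does, is an assumption for brevity):

from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = Trainer.add_argparse_args(ArgumentParser())

assert parser.parse_args([]).gpus is None                 # no flag -> None, not a callable
assert parser.parse_args(['--gpus', '1']).gpus == 1       # a plain count parses to int
assert parser.parse_args(['--gpus', '0,']).gpus == '0,'   # a device-id list stays a string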

@@ -7,7 +7,7 @@ import pytest

 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.argparse import (
-    _gpus_arg_default,
+    _gpus_allowed_type,
     _int_or_float_type,
     add_argparse_args,
     from_argparse_args,

@@ -205,9 +205,9 @@ def test_add_argparse_args_no_argument_group():
     assert args.my_parameter == 2


-def test_gpus_arg_default():
-    assert _gpus_arg_default('1,2') == '1,2'
-    assert _gpus_arg_default('1') == 1
+def test_gpus_allowed_type():
+    assert _gpus_allowed_type('1,2') == '1,2'
+    assert _gpus_allowed_type('1') == 1


 def test_int_or_float_type():