[CLI] Add support for `--key.help=class` (#10767)

This commit is contained in:
Carlos Mocholí 2021-11-29 15:12:53 +01:00 committed by GitHub
parent 24fc54f07b
commit d3b7492bd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 7 deletions

View File

@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed support for `--key.help=class` with the `LightningCLI` ([#10767](https://github.com/PyTorchLightning/pytorch-lightning/pull/10767))
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))

View File

@ -263,9 +263,27 @@ class LightningArgumentParser(ArgumentParser):
else:
clean_argv.append(arg)
i += 1
# the user requested a help message
help_key = argv_key + ".help"
if help_key in passed_args:
argv_class = passed_args[help_key]
if "." in argv_class:
# user passed the class path directly
class_path = argv_class
else:
# convert shorthand format to the classpath
for cls in classes:
if cls.__name__ == argv_class:
class_path = _class_path_from_class(cls)
break
else:
raise ValueError(f"Could not generate get the class_path for {repr(argv_class)}")
return clean_argv + [help_key, class_path]
# generate the associated config file
argv_class = passed_args.pop(argv_key, None)
if argv_class is None:
argv_class = passed_args.pop(argv_key, "")
if not argv_class:
# the user passed a config as a str
class_path = passed_args[f"{argv_key}.class_path"]
init_args_key = f"{argv_key}.init_args"
@ -769,12 +787,16 @@ class LightningCLI:
return fn_kwargs
def _class_path_from_class(class_type: Type) -> str:
return class_type.__module__ + "." + class_type.__name__
def _global_add_class_path(
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
) -> Dict[str, Any]:
if isinstance(init_args, Namespace):
init_args = init_args.as_dict()
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}
return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:

View File

@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
torchtext>=0.8.*
omegaconf>=2.0.5
hydra-core>=1.0.5
jsonargparse[signatures]>=4.0.0
jsonargparse[signatures]>=4.0.4
gcsfs>=2021.5.0
rich>=10.2.2

View File

@ -57,7 +57,7 @@ if _TORCHVISION_AVAILABLE:
@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse, tmpdir):
def test_default_args(mock_argparse):
"""Tests default argument parser for Trainer."""
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
@ -870,7 +870,7 @@ class CustomCallback(Callback):
pass
def test_registries(tmpdir):
def test_registries():
assert "SGD" in OPTIMIZER_REGISTRY.names
assert "RMSprop" in OPTIMIZER_REGISTRY.names
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
@ -1360,9 +1360,27 @@ def test_lightning_cli_reinstantiate_trainer():
assert cli.config_init["trainer"]["max_epochs"] is None
def test_cli_configure_optimizers_warning(tmpdir):
def test_cli_configure_optimizers_warning():
match = "configure_optimizers` will be overridden by `LightningCLI"
with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
LightningCLI(BoringModel, run=False)
with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
LightningCLI(BoringModel, run=False)
def test_cli_help_message():
# full class path
cli_args = ["any.py", "--optimizer.help=torch.optim.Adam"]
classpath_help = StringIO()
with mock.patch("sys.argv", cli_args), redirect_stdout(classpath_help), pytest.raises(SystemExit):
LightningCLI(BoringModel, run=False)
cli_args = ["any.py", "--optimizer.help=Adam"]
shorthand_help = StringIO()
with mock.patch("sys.argv", cli_args), redirect_stdout(shorthand_help), pytest.raises(SystemExit):
LightningCLI(BoringModel, run=False)
# the help messages should match
assert shorthand_help.getvalue() == classpath_help.getvalue()
# make sure it's not empty
assert "Implements Adam" in shorthand_help.getvalue()