[CLI] Add support for `--key.help=class` (#10767)
This commit is contained in:
parent
24fc54f07b
commit
d3b7492bd0
|
@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed support for `--key.help=class` with the `LightningCLI` ([#10767](https://github.com/PyTorchLightning/pytorch-lightning/pull/10767))
|
||||
|
||||
|
||||
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))
|
||||
|
||||
|
||||
|
|
|
@ -263,9 +263,27 @@ class LightningArgumentParser(ArgumentParser):
|
|||
else:
|
||||
clean_argv.append(arg)
|
||||
i += 1
|
||||
|
||||
# the user requested a help message
|
||||
help_key = argv_key + ".help"
|
||||
if help_key in passed_args:
|
||||
argv_class = passed_args[help_key]
|
||||
if "." in argv_class:
|
||||
# user passed the class path directly
|
||||
class_path = argv_class
|
||||
else:
|
||||
# convert shorthand format to the classpath
|
||||
for cls in classes:
|
||||
if cls.__name__ == argv_class:
|
||||
class_path = _class_path_from_class(cls)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not generate get the class_path for {repr(argv_class)}")
|
||||
return clean_argv + [help_key, class_path]
|
||||
|
||||
# generate the associated config file
|
||||
argv_class = passed_args.pop(argv_key, None)
|
||||
if argv_class is None:
|
||||
argv_class = passed_args.pop(argv_key, "")
|
||||
if not argv_class:
|
||||
# the user passed a config as a str
|
||||
class_path = passed_args[f"{argv_key}.class_path"]
|
||||
init_args_key = f"{argv_key}.init_args"
|
||||
|
@ -769,12 +787,16 @@ class LightningCLI:
|
|||
return fn_kwargs
|
||||
|
||||
|
||||
def _class_path_from_class(class_type: Type) -> str:
|
||||
return class_type.__module__ + "." + class_type.__name__
|
||||
|
||||
|
||||
def _global_add_class_path(
|
||||
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(init_args, Namespace):
|
||||
init_args = init_args.as_dict()
|
||||
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}
|
||||
return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}
|
||||
|
||||
|
||||
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
|
||||
|
|
|
@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
|
|||
torchtext>=0.8.*
|
||||
omegaconf>=2.0.5
|
||||
hydra-core>=1.0.5
|
||||
jsonargparse[signatures]>=4.0.0
|
||||
jsonargparse[signatures]>=4.0.4
|
||||
gcsfs>=2021.5.0
|
||||
rich>=10.2.2
|
||||
|
|
|
@ -57,7 +57,7 @@ if _TORCHVISION_AVAILABLE:
|
|||
|
||||
|
||||
@mock.patch("argparse.ArgumentParser.parse_args")
|
||||
def test_default_args(mock_argparse, tmpdir):
|
||||
def test_default_args(mock_argparse):
|
||||
"""Tests default argument parser for Trainer."""
|
||||
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
|
||||
|
||||
|
@ -870,7 +870,7 @@ class CustomCallback(Callback):
|
|||
pass
|
||||
|
||||
|
||||
def test_registries(tmpdir):
|
||||
def test_registries():
|
||||
assert "SGD" in OPTIMIZER_REGISTRY.names
|
||||
assert "RMSprop" in OPTIMIZER_REGISTRY.names
|
||||
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
|
||||
|
@ -1360,9 +1360,27 @@ def test_lightning_cli_reinstantiate_trainer():
|
|||
assert cli.config_init["trainer"]["max_epochs"] is None
|
||||
|
||||
|
||||
def test_cli_configure_optimizers_warning(tmpdir):
|
||||
def test_cli_configure_optimizers_warning():
|
||||
match = "configure_optimizers` will be overridden by `LightningCLI"
|
||||
with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
|
||||
|
||||
def test_cli_help_message():
|
||||
# full class path
|
||||
cli_args = ["any.py", "--optimizer.help=torch.optim.Adam"]
|
||||
classpath_help = StringIO()
|
||||
with mock.patch("sys.argv", cli_args), redirect_stdout(classpath_help), pytest.raises(SystemExit):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
|
||||
cli_args = ["any.py", "--optimizer.help=Adam"]
|
||||
shorthand_help = StringIO()
|
||||
with mock.patch("sys.argv", cli_args), redirect_stdout(shorthand_help), pytest.raises(SystemExit):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
|
||||
# the help messages should match
|
||||
assert shorthand_help.getvalue() == classpath_help.getvalue()
|
||||
# make sure it's not empty
|
||||
assert "Implements Adam" in shorthand_help.getvalue()
|
||||
|
|
Loading…
Reference in New Issue