diff --git a/CHANGELOG.md b/CHANGELOG.md index bb5d26a107..08c0fc05c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed support for `--key.help=class` with the `LightningCLI` ([#10767](https://github.com/PyTorchLightning/pytorch-lightning/pull/10767)) + + - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index b08ad7265c..601772ef9d 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -263,9 +263,27 @@ class LightningArgumentParser(ArgumentParser): else: clean_argv.append(arg) i += 1 + + # the user requested a help message + help_key = argv_key + ".help" + if help_key in passed_args: + argv_class = passed_args[help_key] + if "." in argv_class: + # user passed the class path directly + class_path = argv_class + else: + # convert shorthand format to the classpath + for cls in classes: + if cls.__name__ == argv_class: + class_path = _class_path_from_class(cls) + break + else: + raise ValueError(f"Could not generate get the class_path for {repr(argv_class)}") + return clean_argv + [help_key, class_path] + # generate the associated config file - argv_class = passed_args.pop(argv_key, None) - if argv_class is None: + argv_class = passed_args.pop(argv_key, "") + if not argv_class: # the user passed a config as a str class_path = passed_args[f"{argv_key}.class_path"] init_args_key = f"{argv_key}.init_args" @@ -769,12 +787,16 @@ class LightningCLI: return fn_kwargs +def _class_path_from_class(class_type: Type) -> str: + return class_type.__module__ + "." + class_type.__name__ + + def _global_add_class_path( class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None ) -> Dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() - return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}} + return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: diff --git a/requirements/extra.txt b/requirements/extra.txt index 6abf3089b8..babaffca62 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta torchtext>=0.8.* omegaconf>=2.0.5 hydra-core>=1.0.5 -jsonargparse[signatures]>=4.0.0 +jsonargparse[signatures]>=4.0.4 gcsfs>=2021.5.0 rich>=10.2.2 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 1d6146f16e..7b6d52c7cd 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -57,7 +57,7 @@ if _TORCHVISION_AVAILABLE: @mock.patch("argparse.ArgumentParser.parse_args") -def test_default_args(mock_argparse, tmpdir): +def test_default_args(mock_argparse): """Tests default argument parser for Trainer.""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) @@ -870,7 +870,7 @@ class CustomCallback(Callback): pass -def test_registries(tmpdir): +def test_registries(): assert "SGD" in OPTIMIZER_REGISTRY.names assert "RMSprop" in OPTIMIZER_REGISTRY.names assert "CustomAdam" in OPTIMIZER_REGISTRY.names @@ -1360,9 +1360,27 @@ def test_lightning_cli_reinstantiate_trainer(): assert cli.config_init["trainer"]["max_epochs"] is None -def test_cli_configure_optimizers_warning(tmpdir): +def test_cli_configure_optimizers_warning(): match = "configure_optimizers` will be overridden by `LightningCLI" with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match): LightningCLI(BoringModel, run=False) with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match): LightningCLI(BoringModel, run=False) + + +def test_cli_help_message(): + # full class path + cli_args = ["any.py", "--optimizer.help=torch.optim.Adam"] + classpath_help = StringIO() + with mock.patch("sys.argv", cli_args), redirect_stdout(classpath_help), pytest.raises(SystemExit): + LightningCLI(BoringModel, run=False) + + cli_args = ["any.py", "--optimizer.help=Adam"] + shorthand_help = StringIO() + with mock.patch("sys.argv", cli_args), redirect_stdout(shorthand_help), pytest.raises(SystemExit): + LightningCLI(BoringModel, run=False) + + # the help messages should match + assert shorthand_help.getvalue() == classpath_help.getvalue() + # make sure it's not empty + assert "Implements Adam" in shorthand_help.getvalue()