LightningCLI changes for jsonargparse>=4.0.0 (#10426)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
Mauricio Villegas 2021-11-19 18:03:14 +01:00 committed by GitHub
parent ff8ac6e2e1
commit 5d748e560b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 21 additions and 17 deletions

View File

@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))

View File

@ -14,7 +14,6 @@
import inspect
import os
import sys
from argparse import Namespace
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from unittest import mock
@ -32,13 +31,12 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple
if _JSONARGPARSE_AVAILABLE:
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
from jsonargparse.actions import _ActionSubCommands
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
from jsonargparse.optionals import import_docstring_parse
set_config_read_mode(fsspec_enabled=True)
else:
ArgumentParser = object
ArgumentParser = Namespace = object
class _Registry(dict):
@ -100,7 +98,7 @@ class LightningArgumentParser(ArgumentParser):
# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
For full details of accepted arguments see `ArgumentParser.__init__
@ -109,9 +107,9 @@ class LightningArgumentParser(ArgumentParser):
if not _JSONARGPARSE_AVAILABLE:
raise ModuleNotFoundError(
"`jsonargparse` is not installed but it is required for the CLI."
" Install it with `pip install jsonargparse[signatures]`."
" Install it with `pip install -U jsonargparse[signatures]`."
)
super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
super().__init__(*args, **kwargs)
self.add_argument(
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
)
@ -363,7 +361,7 @@ class SaveConfigCallback(Callback):
def __init__(
self,
parser: LightningArgumentParser,
config: Union[Namespace, Dict[str, Any]],
config: Namespace,
config_filename: str,
overwrite: bool = False,
multifile: bool = False,
@ -671,8 +669,7 @@ class LightningCLI:
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)]
action_subcommand = action_subcommands[0]
action_subcommand = self.parser._subcommands_action
return action_subcommand._name_parser_map[subcommand]
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
@ -772,12 +769,16 @@ class LightningCLI:
return fn_kwargs
def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]:
def _global_add_class_path(
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
) -> Dict[str, Any]:
if isinstance(init_args, Namespace):
init_args = init_args.as_dict()
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}
def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]:
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
def add_class_path(init_args: Namespace) -> Dict[str, Any]:
return _global_add_class_path(class_type, init_args)
return add_class_path

View File

@ -85,7 +85,7 @@ _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.grou
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _module_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") and _compare_version("jsonargparse", operator.ge, "4.0.0")
_KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
_NEPTUNE_AVAILABLE = _module_available("neptune")
_NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")

View File

@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
torchtext>=0.8.*
omegaconf>=2.0.5
hydra-core>=1.0.5
jsonargparse[signatures]>=3.19.3
jsonargparse[signatures]>=4.0.0
gcsfs>=2021.5.0
rich>=10.2.2

View File

@ -348,7 +348,7 @@ def test_lightning_cli_args(tmpdir):
loaded_config = yaml.safe_load(f.read())
loaded_config = loaded_config["fit"]
cli_config = cli.config["fit"]
cli_config = cli.config["fit"].as_dict()
assert cli_config["seed_everything"] == 1234
assert "model" not in loaded_config and "model" not in cli_config # no arguments to include
@ -404,7 +404,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
loaded_config = yaml.safe_load(f.read())
loaded_config = loaded_config["fit"]
cli_config = cli.config["fit"]
cli_config = cli.config["fit"].as_dict()
assert loaded_config["model"] == cli_config["model"]
assert loaded_config["data"] == cli_config["data"]