LightningCLI changes for jsonargparse>=4.0.0 (#10426)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
ff8ac6e2e1
commit
5d748e560b
|
@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
|
||||
|
||||
|
||||
- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
|
||||
|
||||
|
||||
- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from types import MethodType, ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from unittest import mock
|
||||
|
@ -32,13 +31,12 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
|
|||
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple
|
||||
|
||||
if _JSONARGPARSE_AVAILABLE:
|
||||
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
|
||||
from jsonargparse.actions import _ActionSubCommands
|
||||
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
|
||||
from jsonargparse.optionals import import_docstring_parse
|
||||
|
||||
set_config_read_mode(fsspec_enabled=True)
|
||||
else:
|
||||
ArgumentParser = object
|
||||
ArgumentParser = Namespace = object
|
||||
|
||||
|
||||
class _Registry(dict):
|
||||
|
@ -100,7 +98,7 @@ class LightningArgumentParser(ArgumentParser):
|
|||
# use class attribute because `parse_args` is only called on the main parser
|
||||
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}
|
||||
|
||||
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize argument parser that supports configuration file input.
|
||||
|
||||
For full details of accepted arguments see `ArgumentParser.__init__
|
||||
|
@ -109,9 +107,9 @@ class LightningArgumentParser(ArgumentParser):
|
|||
if not _JSONARGPARSE_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
"`jsonargparse` is not installed but it is required for the CLI."
|
||||
" Install it with `pip install jsonargparse[signatures]`."
|
||||
" Install it with `pip install -U jsonargparse[signatures]`."
|
||||
)
|
||||
super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_argument(
|
||||
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
|
||||
)
|
||||
|
@ -363,7 +361,7 @@ class SaveConfigCallback(Callback):
|
|||
def __init__(
|
||||
self,
|
||||
parser: LightningArgumentParser,
|
||||
config: Union[Namespace, Dict[str, Any]],
|
||||
config: Namespace,
|
||||
config_filename: str,
|
||||
overwrite: bool = False,
|
||||
multifile: bool = False,
|
||||
|
@ -671,8 +669,7 @@ class LightningCLI:
|
|||
if subcommand is None:
|
||||
return self.parser
|
||||
# return the subcommand parser for the subcommand passed
|
||||
action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)]
|
||||
action_subcommand = action_subcommands[0]
|
||||
action_subcommand = self.parser._subcommands_action
|
||||
return action_subcommand._name_parser_map[subcommand]
|
||||
|
||||
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
|
||||
|
@ -772,12 +769,16 @@ class LightningCLI:
|
|||
return fn_kwargs
|
||||
|
||||
|
||||
def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def _global_add_class_path(
|
||||
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(init_args, Namespace):
|
||||
init_args = init_args.as_dict()
|
||||
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}
|
||||
|
||||
|
||||
def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
|
||||
def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
|
||||
def add_class_path(init_args: Namespace) -> Dict[str, Any]:
|
||||
return _global_add_class_path(class_type, init_args)
|
||||
|
||||
return add_class_path
|
||||
|
|
|
@ -85,7 +85,7 @@ _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.grou
|
|||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_HYDRA_AVAILABLE = _module_available("hydra")
|
||||
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
||||
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
|
||||
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") and _compare_version("jsonargparse", operator.ge, "4.0.0")
|
||||
_KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
|
||||
_NEPTUNE_AVAILABLE = _module_available("neptune")
|
||||
_NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
|
||||
|
|
|
@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
|
|||
torchtext>=0.8.*
|
||||
omegaconf>=2.0.5
|
||||
hydra-core>=1.0.5
|
||||
jsonargparse[signatures]>=3.19.3
|
||||
jsonargparse[signatures]>=4.0.0
|
||||
gcsfs>=2021.5.0
|
||||
rich>=10.2.2
|
||||
|
|
|
@ -348,7 +348,7 @@ def test_lightning_cli_args(tmpdir):
|
|||
loaded_config = yaml.safe_load(f.read())
|
||||
|
||||
loaded_config = loaded_config["fit"]
|
||||
cli_config = cli.config["fit"]
|
||||
cli_config = cli.config["fit"].as_dict()
|
||||
|
||||
assert cli_config["seed_everything"] == 1234
|
||||
assert "model" not in loaded_config and "model" not in cli_config # no arguments to include
|
||||
|
@ -404,7 +404,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
|
|||
loaded_config = yaml.safe_load(f.read())
|
||||
|
||||
loaded_config = loaded_config["fit"]
|
||||
cli_config = cli.config["fit"]
|
||||
cli_config = cli.config["fit"].as_dict()
|
||||
|
||||
assert loaded_config["model"] == cli_config["model"]
|
||||
assert loaded_config["data"] == cli_config["data"]
|
||||
|
|
Loading…
Reference in New Issue