From 5d748e560b72ba3d3c93de683548a67a8426e29c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 19 Nov 2021 18:03:14 +0100 Subject: [PATCH] LightningCLI changes for jsonargparse>=4.0.0 (#10426) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ Co-authored-by: thomas chaton --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/cli.py | 27 +++++++++++++------------- pytorch_lightning/utilities/imports.py | 2 +- requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 4 ++-- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e52442c7b..0f5f3644b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426)) + + - Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6ed485257f..b08ad7265c 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -14,7 +14,6 @@ import inspect import os import sys -from argparse import Namespace from types import MethodType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from unittest import mock @@ -32,13 +31,12 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple if _JSONARGPARSE_AVAILABLE: - from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode - from jsonargparse.actions import _ActionSubCommands + from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode from jsonargparse.optionals import import_docstring_parse set_config_read_mode(fsspec_enabled=True) else: - ArgumentParser = object + ArgumentParser = Namespace = object class _Registry(dict): @@ -100,7 +98,7 @@ class LightningArgumentParser(ArgumentParser): # use class attribute because `parse_args` is only called on the main parser _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {} - def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. For full details of accepted arguments see `ArgumentParser.__init__ @@ -109,9 +107,9 @@ class LightningArgumentParser(ArgumentParser): if not _JSONARGPARSE_AVAILABLE: raise ModuleNotFoundError( "`jsonargparse` is not installed but it is required for the CLI." - " Install it with `pip install jsonargparse[signatures]`." + " Install it with `pip install -U jsonargparse[signatures]`." ) - super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) + super().__init__(*args, **kwargs) self.add_argument( "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." ) @@ -363,7 +361,7 @@ class SaveConfigCallback(Callback): def __init__( self, parser: LightningArgumentParser, - config: Union[Namespace, Dict[str, Any]], + config: Namespace, config_filename: str, overwrite: bool = False, multifile: bool = False, @@ -671,8 +669,7 @@ class LightningCLI: if subcommand is None: return self.parser # return the subcommand parser for the subcommand passed - action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)] - action_subcommand = action_subcommands[0] + action_subcommand = self.parser._subcommands_action return action_subcommand._name_parser_map[subcommand] def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: @@ -772,12 +769,16 @@ class LightningCLI: return fn_kwargs -def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]: +def _global_add_class_path( + class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None +) -> Dict[str, Any]: + if isinstance(init_args, Namespace): + init_args = init_args.as_dict() return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}} -def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]: +def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: + def add_class_path(init_args: Namespace) -> Dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 5db24fe0f5..aa6349b5d6 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -85,7 +85,7 @@ _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.grou _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") -_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") +_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") and _compare_version("jsonargparse", operator.ge, "4.0.0") _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() _NEPTUNE_AVAILABLE = _module_available("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") diff --git a/requirements/extra.txt b/requirements/extra.txt index 4aea9dad9c..6abf3089b8 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta torchtext>=0.8.* omegaconf>=2.0.5 hydra-core>=1.0.5 -jsonargparse[signatures]>=3.19.3 +jsonargparse[signatures]>=4.0.0 gcsfs>=2021.5.0 rich>=10.2.2 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5f824d1bee..1d6146f16e 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -348,7 +348,7 @@ def test_lightning_cli_args(tmpdir): loaded_config = yaml.safe_load(f.read()) loaded_config = loaded_config["fit"] - cli_config = cli.config["fit"] + cli_config = cli.config["fit"].as_dict() assert cli_config["seed_everything"] == 1234 assert "model" not in loaded_config and "model" not in cli_config # no arguments to include @@ -404,7 +404,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): loaded_config = yaml.safe_load(f.read()) loaded_config = loaded_config["fit"] - cli_config = cli.config["fit"] + cli_config = cli.config["fit"].as_dict() assert loaded_config["model"] == cli_config["model"] assert loaded_config["data"] == cli_config["data"]