diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 0159dd473..0dd2ee380 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -7,6 +7,7 @@ import srsly import hashlib import typer from click import NoSuchOption +from click.parser import split_arg_string from typer.main import get_command from contextlib import contextmanager from thinc.config import Config, ConfigValidationError @@ -38,6 +39,7 @@ commands to check and validate your config files, training and evaluation data, and custom model implementations. """ INIT_HELP = """Commands for initializing configs and pipeline packages.""" +OVERRIDES_ENV_VAR = "SPACY_CONFIG_OVERRIDES" # Wrappers for Typer's annotations. Initially created to set defaults and to # keep the names short, but not needed at the moment. @@ -62,46 +64,41 @@ def setup_cli() -> None: command(prog_name=COMMAND) -def parse_config_env_overrides( - *, prefix: str = "SPACY_CONFIG_", dot: str = "__" +def parse_config_overrides( + args: List[str], env_var: Optional[str] = OVERRIDES_ENV_VAR ) -> Dict[str, Any]: - """Generate a dictionary of config overrides based on environment variables, - e.g. SPACY_CONFIG_TRAINING__BATCH_SIZE=123 overrides the training.batch_size - setting. - - prefix (str): The env variable prefix for config overrides. - dot (str): String used to represent the "dot", e.g. in training.batch_size. - RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. - """ - result = {} - for env_key, value in os.environ.items(): - if env_key.startswith(prefix): - opt = env_key[len(prefix) :].lower().replace(dot, ".") - if "." in opt: - result[opt] = try_json_loads(value) - return result - - -def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str, Any]: """Generate a dictionary of config overrides based on the extra arguments provided on the CLI, e.g. --training.batch_size to override "training.batch_size". Arguments without a "." 
are considered invalid, since the config only allows top-level sections to exist. - args (List[str]): The extra arguments from the command line. - env_vars (bool): Include environment variables. + env_var (Optional[str]): Optional environment variable to read from. RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. """ - env_overrides = parse_config_env_overrides() if env_vars else {} - cli_overrides = {} + env_string = os.environ.get(env_var, "") if env_var else "" + env_overrides = _parse_overrides(split_arg_string(env_string)) + cli_overrides = _parse_overrides(args, is_cli=True) + if cli_overrides: + keys = [k for k in cli_overrides if k not in env_overrides] + logger.debug(f"Config overrides from CLI: {keys}") + if env_overrides: + logger.debug(f"Config overrides from env variables: {list(env_overrides)}") + return {**cli_overrides, **env_overrides} + + +def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]: + result = {} while args: opt = args.pop(0) - err = f"Invalid CLI argument '{opt}'" + err = f"Invalid config override '{opt}'" if opt.startswith("--"): # new argument orig_opt = opt opt = opt.replace("--", "") if "." not in opt: - raise NoSuchOption(orig_opt) + if is_cli: + raise NoSuchOption(orig_opt) + else: + msg.fail(f"{err}: can't override top-level sections", exits=1) if "=" in opt: # we have --opt=value opt, value = opt.split("=", 1) opt = opt.replace("-", "_") @@ -110,27 +107,18 @@ def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str, value = "true" else: value = args.pop(0) - if opt not in env_overrides: - cli_overrides[opt] = try_json_loads(value) + # Just like we do in the config, we're calling json.loads on the + # values. But since they come from the CLI, it'd be unintuitive to + # explicitly mark strings with escaped quotes. So we're working + # around that here by falling back to a string if parsing fails. 
+ # TODO: improve logic to handle simple types like list of strings? + try: + result[opt] = srsly.json_loads(value) + except ValueError: + result[opt] = str(value) else: - msg.fail(f"{err}: override option should start with --", exits=1) - if cli_overrides: - logger.debug(f"Config overrides from CLI: {list(cli_overrides)}") - if env_overrides: - logger.debug(f"Config overrides from env variables: {list(env_overrides)}") - return {**cli_overrides, **env_overrides} - - -def try_json_loads(value: Any) -> Any: - # Just like we do in the config, we're calling json.loads on the - # values. But since they come from the CLI, it'd be unintuitive to - # explicitly mark strings with escaped quotes. So we're working - # around that here by falling back to a string if parsing fails. - # TODO: improve logic to handle simple types like list of strings? - try: - return srsly.json_loads(value) - except ValueError: - return str(value) + msg.fail(f"{err}: name should start with --", exits=1) + return result def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]: diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index d81437f18..a9c9d8ca5 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -6,7 +6,7 @@ from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables -from spacy.cli._util import string_to_list, parse_config_env_overrides +from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR from thinc.config import ConfigValidationError import srsly import os @@ -342,15 +342,21 @@ def test_parse_config_overrides_invalid_2(args): def test_parse_cli_overrides(): - prefix = "SPACY_CONFIG_" - dot = "__" - os.environ[f"{prefix}TRAINING{dot}BATCH_SIZE"] = "123" - 
os.environ[f"{prefix}FOO{dot}BAR{dot}BAZ"] = "hello" - os.environ[prefix] = "bad" - result = parse_config_env_overrides(prefix=prefix, dot=dot) - assert len(result) == 2 - assert result["training.batch_size"] == 123 - assert result["foo.bar.baz"] == "hello" + os.environ[OVERRIDES_ENV_VAR] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello" + result = parse_config_overrides([]) + assert len(result) == 4 + assert result["x.foo"] == "bar" + assert result["x.bar"] == 12 + assert result["x.baz"] is False + assert result["y.foo"] == "hello" + os.environ[OVERRIDES_ENV_VAR] = "--x" + assert parse_config_overrides([], env_var=None) == {} + with pytest.raises(SystemExit): + parse_config_overrides([]) + os.environ[OVERRIDES_ENV_VAR] = "hello world" + with pytest.raises(SystemExit): + parse_config_overrides([]) + del os.environ[OVERRIDES_ENV_VAR] @pytest.mark.parametrize("lang", ["en", "nl"])