Merge pull request #6096 from explosion/feature/config-overrides-env-vars

This commit is contained in:
Ines Montani 2020-09-21 14:46:19 +02:00 committed by GitHub
commit 4b79d697ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 8 deletions

View File

@ -7,13 +7,15 @@ import srsly
import hashlib import hashlib
import typer import typer
from click import NoSuchOption from click import NoSuchOption
from click.parser import split_arg_string
from typer.main import get_command from typer.main import get_command
from contextlib import contextmanager from contextlib import contextmanager
from thinc.config import Config, ConfigValidationError from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError from configparser import InterpolationError
import os
from ..schemas import ProjectConfigSchema, validate from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry from ..util import import_file, run_command, make_tempdir, registry, logger
if TYPE_CHECKING: if TYPE_CHECKING:
from pathy import Pathy # noqa: F401 from pathy import Pathy # noqa: F401
@ -37,6 +39,7 @@ commands to check and validate your config files, training and evaluation data,
and custom model implementations. and custom model implementations.
""" """
INIT_HELP = """Commands for initializing configs and pipeline packages.""" INIT_HELP = """Commands for initializing configs and pipeline packages."""
OVERRIDES_ENV_VAR = "SPACY_CONFIG_OVERRIDES"
# Wrappers for Typer's annotations. Initially created to set defaults and to # Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment. # keep the names short, but not needed at the moment.
@ -61,24 +64,41 @@ def setup_cli() -> None:
command(prog_name=COMMAND) command(prog_name=COMMAND)
def parse_config_overrides(args: List[str]) -> Dict[str, Any]: def parse_config_overrides(
args: List[str], env_var: Optional[str] = OVERRIDES_ENV_VAR
) -> Dict[str, Any]:
"""Generate a dictionary of config overrides based on the extra arguments """Generate a dictionary of config overrides based on the extra arguments
provided on the CLI, e.g. --training.batch_size to override provided on the CLI, e.g. --training.batch_size to override
"training.batch_size". Arguments without a "." are considered invalid, "training.batch_size". Arguments without a "." are considered invalid,
since the config only allows top-level sections to exist. since the config only allows top-level sections to exist.
args (List[str]): The extra arguments from the command line. env_vars (Optional[str]): Optional environment variable to read from.
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
""" """
env_string = os.environ.get(env_var, "") if env_var else ""
env_overrides = _parse_overrides(split_arg_string(env_string))
cli_overrides = _parse_overrides(args, is_cli=True)
if cli_overrides:
keys = [k for k in cli_overrides if k not in env_overrides]
logger.debug(f"Config overrides from CLI: {keys}")
if env_overrides:
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
return {**cli_overrides, **env_overrides}
def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
result = {} result = {}
while args: while args:
opt = args.pop(0) opt = args.pop(0)
err = f"Invalid CLI argument '{opt}'" err = f"Invalid config override '{opt}'"
if opt.startswith("--"): # new argument if opt.startswith("--"): # new argument
orig_opt = opt orig_opt = opt
opt = opt.replace("--", "") opt = opt.replace("--", "")
if "." not in opt: if "." not in opt:
if is_cli:
raise NoSuchOption(orig_opt) raise NoSuchOption(orig_opt)
else:
msg.fail(f"{err}: can't override top-level sections", exits=1)
if "=" in opt: # we have --opt=value if "=" in opt: # we have --opt=value
opt, value = opt.split("=", 1) opt, value = opt.split("=", 1)
opt = opt.replace("-", "_") opt = opt.replace("-", "_")
@ -97,7 +117,7 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
except ValueError: except ValueError:
result[opt] = str(value) result[opt] = str(value)
else: else:
msg.fail(f"{err}: override option should start with --", exits=1) msg.fail(f"{err}: name should start with --", exits=1)
return result return result

View File

@ -1,15 +1,15 @@
import pytest import pytest
from click import NoSuchOption from click import NoSuchOption
from spacy.training import docs_to_json, biluo_tags_from_offsets from spacy.training import docs_to_json, biluo_tags_from_offsets
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR
from thinc.config import ConfigValidationError from thinc.config import ConfigValidationError
import srsly import srsly
import os
from .util import make_tempdir from .util import make_tempdir
@ -341,6 +341,24 @@ def test_parse_config_overrides_invalid_2(args):
parse_config_overrides(args) parse_config_overrides(args)
def test_parse_cli_overrides():
os.environ[OVERRIDES_ENV_VAR] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello"
result = parse_config_overrides([])
assert len(result) == 4
assert result["x.foo"] == "bar"
assert result["x.bar"] == 12
assert result["x.baz"] is False
assert result["y.foo"] == "hello"
os.environ[OVERRIDES_ENV_VAR] = "--x"
assert parse_config_overrides([], env_var=None) == {}
with pytest.raises(SystemExit):
parse_config_overrides([])
os.environ[OVERRIDES_ENV_VAR] = "hello world"
with pytest.raises(SystemExit):
parse_config_overrides([])
del os.environ[OVERRIDES_ENV_VAR]
@pytest.mark.parametrize("lang", ["en", "nl"]) @pytest.mark.parametrize("lang", ["en", "nl"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]] "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]