mirror of https://github.com/explosion/spaCy.git
Merge pull request #6096 from explosion/feature/config-overrides-env-vars
This commit is contained in:
commit
4b79d697ee
|
@ -7,13 +7,15 @@ import srsly
|
||||||
import hashlib
|
import hashlib
|
||||||
import typer
|
import typer
|
||||||
from click import NoSuchOption
|
from click import NoSuchOption
|
||||||
|
from click.parser import split_arg_string
|
||||||
from typer.main import get_command
|
from typer.main import get_command
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from thinc.config import Config, ConfigValidationError
|
from thinc.config import Config, ConfigValidationError
|
||||||
from configparser import InterpolationError
|
from configparser import InterpolationError
|
||||||
|
import os
|
||||||
|
|
||||||
from ..schemas import ProjectConfigSchema, validate
|
from ..schemas import ProjectConfigSchema, validate
|
||||||
from ..util import import_file, run_command, make_tempdir, registry
|
from ..util import import_file, run_command, make_tempdir, registry, logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathy import Pathy # noqa: F401
|
from pathy import Pathy # noqa: F401
|
||||||
|
@ -37,6 +39,7 @@ commands to check and validate your config files, training and evaluation data,
|
||||||
and custom model implementations.
|
and custom model implementations.
|
||||||
"""
|
"""
|
||||||
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
||||||
|
OVERRIDES_ENV_VAR = "SPACY_CONFIG_OVERRIDES"
|
||||||
|
|
||||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||||
# keep the names short, but not needed at the moment.
|
# keep the names short, but not needed at the moment.
|
||||||
|
@ -61,24 +64,41 @@ def setup_cli() -> None:
|
||||||
command(prog_name=COMMAND)
|
command(prog_name=COMMAND)
|
||||||
|
|
||||||
|
|
||||||
def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
|
def parse_config_overrides(
|
||||||
|
args: List[str], env_var: Optional[str] = OVERRIDES_ENV_VAR
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Generate a dictionary of config overrides based on the extra arguments
|
"""Generate a dictionary of config overrides based on the extra arguments
|
||||||
provided on the CLI, e.g. --training.batch_size to override
|
provided on the CLI, e.g. --training.batch_size to override
|
||||||
"training.batch_size". Arguments without a "." are considered invalid,
|
"training.batch_size". Arguments without a "." are considered invalid,
|
||||||
since the config only allows top-level sections to exist.
|
since the config only allows top-level sections to exist.
|
||||||
|
|
||||||
args (List[str]): The extra arguments from the command line.
|
env_vars (Optional[str]): Optional environment variable to read from.
|
||||||
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
|
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
|
||||||
"""
|
"""
|
||||||
|
env_string = os.environ.get(env_var, "") if env_var else ""
|
||||||
|
env_overrides = _parse_overrides(split_arg_string(env_string))
|
||||||
|
cli_overrides = _parse_overrides(args, is_cli=True)
|
||||||
|
if cli_overrides:
|
||||||
|
keys = [k for k in cli_overrides if k not in env_overrides]
|
||||||
|
logger.debug(f"Config overrides from CLI: {keys}")
|
||||||
|
if env_overrides:
|
||||||
|
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
|
||||||
|
return {**cli_overrides, **env_overrides}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
while args:
|
while args:
|
||||||
opt = args.pop(0)
|
opt = args.pop(0)
|
||||||
err = f"Invalid CLI argument '{opt}'"
|
err = f"Invalid config override '{opt}'"
|
||||||
if opt.startswith("--"): # new argument
|
if opt.startswith("--"): # new argument
|
||||||
orig_opt = opt
|
orig_opt = opt
|
||||||
opt = opt.replace("--", "")
|
opt = opt.replace("--", "")
|
||||||
if "." not in opt:
|
if "." not in opt:
|
||||||
raise NoSuchOption(orig_opt)
|
if is_cli:
|
||||||
|
raise NoSuchOption(orig_opt)
|
||||||
|
else:
|
||||||
|
msg.fail(f"{err}: can't override top-level sections", exits=1)
|
||||||
if "=" in opt: # we have --opt=value
|
if "=" in opt: # we have --opt=value
|
||||||
opt, value = opt.split("=", 1)
|
opt, value = opt.split("=", 1)
|
||||||
opt = opt.replace("-", "_")
|
opt = opt.replace("-", "_")
|
||||||
|
@ -97,7 +117,7 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
result[opt] = str(value)
|
result[opt] = str(value)
|
||||||
else:
|
else:
|
||||||
msg.fail(f"{err}: override option should start with --", exits=1)
|
msg.fail(f"{err}: name should start with --", exits=1)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
import pytest
|
import pytest
|
||||||
from click import NoSuchOption
|
from click import NoSuchOption
|
||||||
|
|
||||||
from spacy.training import docs_to_json, biluo_tags_from_offsets
|
from spacy.training import docs_to_json, biluo_tags_from_offsets
|
||||||
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
|
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
|
||||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||||
from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
||||||
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
||||||
from spacy.cli._util import load_project_config, substitute_project_variables
|
from spacy.cli._util import load_project_config, substitute_project_variables
|
||||||
from spacy.cli._util import string_to_list
|
from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR
|
||||||
from thinc.config import ConfigValidationError
|
from thinc.config import ConfigValidationError
|
||||||
import srsly
|
import srsly
|
||||||
|
import os
|
||||||
|
|
||||||
from .util import make_tempdir
|
from .util import make_tempdir
|
||||||
|
|
||||||
|
@ -341,6 +341,24 @@ def test_parse_config_overrides_invalid_2(args):
|
||||||
parse_config_overrides(args)
|
parse_config_overrides(args)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_cli_overrides():
|
||||||
|
os.environ[OVERRIDES_ENV_VAR] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello"
|
||||||
|
result = parse_config_overrides([])
|
||||||
|
assert len(result) == 4
|
||||||
|
assert result["x.foo"] == "bar"
|
||||||
|
assert result["x.bar"] == 12
|
||||||
|
assert result["x.baz"] is False
|
||||||
|
assert result["y.foo"] == "hello"
|
||||||
|
os.environ[OVERRIDES_ENV_VAR] = "--x"
|
||||||
|
assert parse_config_overrides([], env_var=None) == {}
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parse_config_overrides([])
|
||||||
|
os.environ[OVERRIDES_ENV_VAR] = "hello world"
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parse_config_overrides([])
|
||||||
|
del os.environ[OVERRIDES_ENV_VAR]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("lang", ["en", "nl"])
|
@pytest.mark.parametrize("lang", ["en", "nl"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
|
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
|
||||||
|
|
Loading…
Reference in New Issue