From 5e3b796b122fc9b1125f350b5dcda625fd9740f0 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 22 Sep 2020 12:24:39 +0200 Subject: [PATCH] Validate section refs in debug config --- spacy/cli/debug_config.py | 27 +++++++++++++++++++++++++-- spacy/tests/test_cli.py | 15 ++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/spacy/cli/debug_config.py b/spacy/cli/debug_config.py index 7930d0674..d07a0bb2d 100644 --- a/spacy/cli/debug_config.py +++ b/spacy/cli/debug_config.py @@ -2,7 +2,7 @@ from typing import Optional, Dict, Any, Union, List from pathlib import Path from wasabi import msg, table from thinc.api import Config -from thinc.config import VARIABLE_RE +from thinc.config import VARIABLE_RE, ConfigValidationError import typer from ._util import Arg, Opt, show_validation_error, parse_config_overrides @@ -51,7 +51,10 @@ def debug_config( msg.divider("Config validation") with show_validation_error(config_path): config = util.load_config(config_path, overrides=overrides) - nlp, _ = util.load_model_from_config(config) + nlp, resolved = util.load_model_from_config(config) + # Use the resolved config here in case user has one function returning + # a dict of corpora etc. + check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"]) msg.good("Config is valid") if show_vars: variables = get_variables(config) @@ -93,3 +96,23 @@ def get_variables(config: Config) -> Dict[str, Any]: value = util.dot_to_object(config, path) result[variable] = repr(value) return result + + +def check_section_refs(config: Config, fields: List[str]) -> None: + """Validate fields in the config that refer to other sections or values + (e.g. in the corpora) and make sure that those references exist. + """ + errors = [] + for field in fields: + # If the field doesn't exist in the config, we ignore it + try: + value = util.dot_to_object(config, field) + except KeyError: + continue + try: + util.dot_to_object(config, value) + except KeyError: + msg = f"not a valid section reference: {value}" + errors.append({"loc": field.split("."), "msg": msg}) + if errors: + raise ConfigValidationError(config, errors) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index a9c9d8ca5..1bc246fef 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -7,7 +7,8 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR -from thinc.config import ConfigValidationError +from spacy.cli.debug_config import check_section_refs +from thinc.config import ConfigValidationError, Config import srsly import os @@ -413,3 +414,15 @@ def test_string_to_list(value): def test_string_to_list_intify(value): assert string_to_list(value, intify=False) == ["1", "2", "3"] assert string_to_list(value, intify=True) == [1, 2, 3] + + +def test_check_section_refs(): + config = {"a": {"b": {"c": "a.d.e"}, "d": {"e": 1}}, "f": {"g": "d.f"}} + config = Config(config) + # Valid section reference + check_section_refs(config, ["a.b.c"]) + # Section that doesn't exist in this config + check_section_refs(config, ["x.y.z"]) + # Invalid section reference + with pytest.raises(ConfigValidationError): + check_section_refs(config, ["a.b.c", "f.g"])