diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py
index d94fbfd4a..92cdba054 100644
--- a/examples/training/train_textcat.py
+++ b/examples/training/train_textcat.py
@@ -45,7 +45,7 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
     print(f"Loading nlp model from {config_path}")
     nlp_config = Config().from_disk(config_path)
     print(f"config: {nlp_config}")
-    nlp, _ = util.load_model_from_config(nlp_config)
+    nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)

     # ensure the nlp object was defined with a textcat component
     if "textcat" not in nlp.pipe_names:
diff --git a/spacy/cli/debug_model.py b/spacy/cli/debug_model.py
index 936a7492e..88e060238 100644
--- a/spacy/cli/debug_model.py
+++ b/spacy/cli/debug_model.py
@@ -8,6 +8,7 @@ import typer
 from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
 from .. import util
 from ..lang.en import English
+from ..util import dot_to_object


 @debug_cli.command("model")
@@ -60,16 +61,7 @@ def debug_model_cli(
         msg.info(f"Fixing random seed: {seed}")
         fix_random_seed(seed)

-    component = config
-    parts = section.split(".")
-    for item in parts:
-        try:
-            component = component[item]
-        except KeyError:
-            msg.fail(
-                f"The section '{section}' is not a valid section in the provided config.",
-                exits=1,
-            )
+    component = dot_to_object(config, section)
     if hasattr(component, "model"):
         model = component.model
     else:
diff --git a/spacy/errors.py b/spacy/errors.py
index df6f82757..862cc71d8 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -592,7 +592,7 @@ class Errors:
             "for the `nlp` pipeline with components {names}.")
     E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
             "the code of the language to initialize it with (for example "
-            "'en' for English).\n\n{config}")
+            "'en' for English) - this can't be 'None'.\n\n{config}")
     E996 = ("Could not parse {file}: {msg}")
     E997 = ("Tokenizer special cases are not allowed to modify the text. "
             "This would map '{chunk}' to '{orth}' given token attributes "
diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py
index 65c33c54a..2577ec675 100644
--- a/spacy/tests/test_util.py
+++ b/spacy/tests/test_util.py
@@ -2,7 +2,13 @@ import pytest

 from .util import get_random_doc

-from spacy.util import minibatch_by_words
+from spacy import util
+from spacy.util import minibatch_by_words, dot_to_object
+from thinc.api import Config, Optimizer
+
+from ..lang.en import English
+from ..lang.nl import Dutch
+from ..language import DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
@@ -56,3 +62,50 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
         minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
     )
     assert [len(batch) for batch in batches] == expected_batches
+
+
+def test_util_dot_section():
+    cfg_string = """
+    [nlp]
+    lang = "en"
+    pipeline = ["textcat"]
+    load_vocab_data = false
+
+    [components]
+
+    [components.textcat]
+    factory = "textcat"
+
+    [components.textcat.model]
+    @architectures = "spacy.TextCatBOW.v1"
+    exclusive_classes = true
+    ngram_size = 1
+    no_output_layer = false
+    """
+    nlp_config = Config().from_str(cfg_string)
+    en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
+    print(en_config)
+
+    default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
+    default_config["nlp"]["lang"] = "nl"
+    nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
+
+    # Test that creation went OK
+    assert isinstance(en_nlp, English)
+    assert isinstance(nl_nlp, Dutch)
+    assert nl_nlp.pipe_names == []
+    assert en_nlp.pipe_names == ["textcat"]
+    assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] == False  # not exclusive_classes
+
+    # Test that default values got overwritten
+    assert not en_config["nlp"]["load_vocab_data"]
+    assert nl_config["nlp"]["load_vocab_data"]  # default value True
+
+    # Test proper functioning of 'dot_to_object'
+    with pytest.raises(KeyError):
+        obj = dot_to_object(en_config, "nlp.pipeline.tagger")
+    with pytest.raises(KeyError):
+        obj = dot_to_object(en_config, "nlp.unknownattribute")
+    assert not dot_to_object(en_config, "nlp.load_vocab_data")
+    assert dot_to_object(nl_config, "nlp.load_vocab_data")
+    assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
diff --git a/spacy/util.py b/spacy/util.py
index c98ce2354..5d8b70961 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -258,7 +258,7 @@ def load_model_from_config(
     if "nlp" not in config:
         raise ValueError(Errors.E985.format(config=config))
     nlp_config = config["nlp"]
-    if "lang" not in nlp_config:
+    if "lang" not in nlp_config or nlp_config["lang"] is None:
         raise ValueError(Errors.E993.format(config=nlp_config))
     # This will automatically handle all codes registered via the languages
     # registry, including custom subclasses provided via entry points
@@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
     return {".".join(key): value for key, value in walk_dict(obj)}


+def dot_to_object(config: Config, section: str):
+    """Convert dot notation of a "section" to a specific part of the Config.
+    e.g. "training.optimizer" would return the Optimizer object.
+    Throws an error if the section is not defined in this config.
+
+    config (Config): The config.
+    section (str): The dot notation of the section in the config.
+    RETURNS: The object denoted by the section
+    """
+    component = config
+    parts = section.split(".")
+    for item in parts:
+        try:
+            component = component[item]
+        except (KeyError, TypeError) as e:
+            msg = f"The section '{section}' is not a valid section in the provided config."
+            raise KeyError(msg)
+    return component
+
+
 def walk_dict(
     node: Dict[str, Any], parent: List[str] = []
 ) -> Iterator[Tuple[List[str], Any]]:
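
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new dot_to_object helper added in spacy/util.py resolves a dotted section path against a Thinc Config. The config string and key names ("training.max_epochs", "training.schedule") are invented for the example; only Config and dot_to_object come from the code above.

from thinc.api import Config
from spacy.util import dot_to_object

cfg = Config().from_str(
    """
    [training]
    max_epochs = 10

    [training.optimizer]
    learn_rate = 0.001
    """
)

# Each dot-separated part indexes one level deeper: cfg["training"]["optimizer"]
optimizer_section = dot_to_object(cfg, "training.optimizer")
assert optimizer_section["learn_rate"] == 0.001

# Leaf values resolve the same way
assert dot_to_object(cfg, "training.max_epochs") == 10

# An unknown path raises a KeyError instead of returning None or failing silently
try:
    dot_to_object(cfg, "training.schedule")
except KeyError as err:
    print(err)

On a raw, unresolved config like this one, dot_to_object returns the nested settings block; in test_util_dot_section above it yields an actual Optimizer only because load_model_from_config hands back a config in which the registered functions have already been resolved.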
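
A second sketch, covering the tightened lang check in load_model_from_config: with this patch applied, an explicit lang = null should now be rejected with error E993 (raised as a ValueError), where previously only a missing key was caught. The minimal config below is made up for illustration.

import pytest
from thinc.api import Config
from spacy import util

no_lang_config = Config().from_str(
    """
    [nlp]
    lang = null
    pipeline = []
    """
)

# "lang" is present but None, which the updated check treats like a missing key
with pytest.raises(ValueError):
    util.load_model_from_config(no_lang_config, auto_fill=True)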