util function dot_to_object and corresponding unit test

This commit is contained in:
svlandeg 2020-07-27 17:50:12 +02:00
parent 674c39bff9
commit 61068e0fb1
5 changed files with 79 additions and 14 deletions

View File

@ -45,7 +45,7 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
print(f"Loading nlp model from {config_path}") print(f"Loading nlp model from {config_path}")
nlp_config = Config().from_disk(config_path) nlp_config = Config().from_disk(config_path)
print(f"config: {nlp_config}") print(f"config: {nlp_config}")
nlp, _ = util.load_model_from_config(nlp_config) nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)
# ensure the nlp object was defined with a textcat component # ensure the nlp object was defined with a textcat component
if "textcat" not in nlp.pipe_names: if "textcat" not in nlp.pipe_names:

View File

@ -8,6 +8,7 @@ import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
from .. import util from .. import util
from ..lang.en import English from ..lang.en import English
from ..util import dot_to_object
@debug_cli.command("model") @debug_cli.command("model")
@ -60,16 +61,7 @@ def debug_model_cli(
msg.info(f"Fixing random seed: {seed}") msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed) fix_random_seed(seed)
component = config component = dot_to_object(config, section)
parts = section.split(".")
for item in parts:
try:
component = component[item]
except KeyError:
msg.fail(
f"The section '{section}' is not a valid section in the provided config.",
exits=1,
)
if hasattr(component, "model"): if hasattr(component, "model"):
model = component.model model = component.model
else: else:

View File

@ -592,7 +592,7 @@ class Errors:
"for the `nlp` pipeline with components {names}.") "for the `nlp` pipeline with components {names}.")
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying " E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
"the code of the language to initialize it with (for example " "the code of the language to initialize it with (for example "
"'en' for English).\n\n{config}") "'en' for English) - this can't be 'None'.\n\n{config}")
E996 = ("Could not parse {file}: {msg}") E996 = ("Could not parse {file}: {msg}")
E997 = ("Tokenizer special cases are not allowed to modify the text. " E997 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes " "This would map '{chunk}' to '{orth}' given token attributes "

View File

@ -2,7 +2,13 @@ import pytest
from .util import get_random_doc from .util import get_random_doc
from spacy.util import minibatch_by_words from spacy import util
from spacy.util import minibatch_by_words, dot_to_object
from thinc.api import Config, Optimizer
from ..lang.en import English
from ..lang.nl import Dutch
from ..language import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -56,3 +62,50 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False) minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
) )
assert [len(batch) for batch in batches] == expected_batches assert [len(batch) for batch in batches] == expected_batches
def test_util_dot_section():
    """Check config loading with auto-fill and lookups via `dot_to_object`."""
    cfg_string = """
[nlp]
lang = "en"
pipeline = ["textcat"]
load_vocab_data = false
[components]
[components.textcat]
factory = "textcat"
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
"""
    nlp_config = Config().from_str(cfg_string)
    en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
    default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
    default_config["nlp"]["lang"] = "nl"
    nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
    # Test that creation went OK
    assert isinstance(en_nlp, English)
    assert isinstance(nl_nlp, Dutch)
    assert nl_nlp.pipe_names == []
    assert en_nlp.pipe_names == ["textcat"]
    # not exclusive_classes
    assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
    # Test that default values got overwritten
    assert not en_config["nlp"]["load_vocab_data"]
    assert nl_config["nlp"]["load_vocab_data"]  # default value True
    # Test proper functioning of 'dot_to_object'
    with pytest.raises(KeyError):
        # "pipeline" is a list, so it can't be indexed by a component name
        dot_to_object(en_config, "nlp.pipeline.tagger")
    with pytest.raises(KeyError):
        dot_to_object(en_config, "nlp.unknownattribute")
    assert not dot_to_object(en_config, "nlp.load_vocab_data")
    assert dot_to_object(nl_config, "nlp.load_vocab_data")
    assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)

View File

@ -258,7 +258,7 @@ def load_model_from_config(
if "nlp" not in config: if "nlp" not in config:
raise ValueError(Errors.E985.format(config=config)) raise ValueError(Errors.E985.format(config=config))
nlp_config = config["nlp"] nlp_config = config["nlp"]
if "lang" not in nlp_config: if "lang" not in nlp_config or nlp_config["lang"] is None:
raise ValueError(Errors.E993.format(config=nlp_config)) raise ValueError(Errors.E993.format(config=nlp_config))
# This will automatically handle all codes registered via the languages # This will automatically handle all codes registered via the languages
# registry, including custom subclasses provided via entry points # registry, including custom subclasses provided via entry points
@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
return {".".join(key): value for key, value in walk_dict(obj)} return {".".join(key): value for key, value in walk_dict(obj)}
def dot_to_object(config: Config, section: str):
    """Convert dot notation of a "section" to a specific part of the Config.
    e.g. "training.optimizer" would return the Optimizer object.

    config (Config): The config.
    section (str): The dot notation of the section in the config.
    RETURNS: The object denoted by the section.
    RAISES (KeyError): If any part of the section is not defined in this
        config, or an intermediate value can't be indexed by a string key.
    """
    component = config
    for item in section.split("."):
        try:
            component = component[item]
        except (KeyError, TypeError) as e:
            # TypeError covers intermediate values that don't support string
            # keys (e.g. a list or a bool); report both cases uniformly and
            # keep the original exception as the explicit cause.
            msg = f"The section '{section}' is not a valid section in the provided config."
            raise KeyError(msg) from e
    return component
def walk_dict( def walk_dict(
node: Dict[str, Any], parent: List[str] = [] node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]: ) -> Iterator[Tuple[List[str], Any]]: