Improve checks for sourced components (#7490)

* Improve checks for sourced components

* Remove language class checks

* Convert python warning to logger warning

* Remove unused warning

* Fix formatting
This commit is contained in:
Adriane Boyd 2021-04-19 10:36:32 +02:00 committed by GitHub
parent 05bdbe28bb
commit 1ad646cbcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 4 deletions

View File

@ -159,6 +159,8 @@ class Warnings:
"http://spacy.io/usage/v3#jupyter-notebook-gpu") "http://spacy.io/usage/v3#jupyter-notebook-gpu")
W112 = ("The model specified to use for initial vectors ({name}) has no " W112 = ("The model specified to use for initial vectors ({name}) has no "
"vectors. This is almost certainly a mistake.") "vectors. This is almost certainly a mistake.")
W113 = ("Sourced component '{name}' may not work as expected: source "
"vectors are not identical to current pipeline vectors.")
@add_codes @add_codes
@ -651,8 +653,8 @@ class Errors:
"returned the initialized nlp object instead?") "returned the initialized nlp object instead?")
E944 = ("Can't copy pipeline component '{name}' from source '{model}': " E944 = ("Can't copy pipeline component '{name}' from source '{model}': "
"not found in pipeline. Available components: {opts}") "not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded " E945 = ("Can't copy pipeline component '{name}' from source. Expected "
"nlp object, but got: {source}") "loaded nlp object, but got: {source}")
E947 = ("`Matcher.add` received invalid `greedy` argument: expected " E947 = ("`Matcher.add` received invalid `greedy` argument: expected "
"a string value from {expected} but got: '{arg}'") "a string value from {expected} but got: '{arg}'")
E948 = ("`Matcher.add` received invalid 'patterns' argument: expected " E948 = ("`Matcher.add` received invalid 'patterns' argument: expected "

View File

@ -682,9 +682,14 @@ class Language:
name (str): Optional alternative name to use in current pipeline. name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name. RETURNS (Tuple[Callable, str]): The component and its factory name.
""" """
# TODO: handle errors and mismatches (vectors etc.) # Check source type
if not isinstance(source, self.__class__): if not isinstance(source, Language):
raise ValueError(Errors.E945.format(name=source_name, source=type(source))) raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
# Check vectors, with faster checks first
if self.vocab.vectors.shape != source.vocab.vectors.shape or \
self.vocab.vectors.key2row != source.vocab.vectors.key2row or \
self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes():
util.logger.warning(Warnings.W113.format(name=source_name))
if not source_name in source.component_names: if not source_name in source.component_names:
raise KeyError( raise KeyError(
Errors.E944.format( Errors.E944.format(

View File

@ -1,4 +1,6 @@
import pytest import pytest
import mock
import logging
from spacy.language import Language from spacy.language import Language
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.de import German from spacy.lang.de import German
@ -402,6 +404,38 @@ def test_pipe_factories_from_source():
nlp.add_pipe("custom", source=source_nlp) nlp.add_pipe("custom", source=source_nlp)
def test_pipe_factories_from_source_language_subclass():
class CustomEnglishDefaults(English.Defaults):
stop_words = set(["custom", "stop"])
@registry.languages("custom_en")
class CustomEnglish(English):
lang = "custom_en"
Defaults = CustomEnglishDefaults
source_nlp = English()
source_nlp.add_pipe("tagger")
# custom subclass
nlp = CustomEnglish()
nlp.add_pipe("tagger", source=source_nlp)
assert "tagger" in nlp.pipe_names
# non-subclass
nlp = German()
nlp.add_pipe("tagger", source=source_nlp)
assert "tagger" in nlp.pipe_names
# mismatched vectors
nlp = English()
nlp.vocab.vectors.resize((1, 4))
nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4])
logger = logging.getLogger("spacy")
with mock.patch.object(logger, "warning") as mock_warning:
nlp.add_pipe("tagger", source=source_nlp)
mock_warning.assert_called()
def test_pipe_factories_from_source_custom(): def test_pipe_factories_from_source_custom():
"""Test adding components from a source model with custom components.""" """Test adding components from a source model with custom components."""
name = "test_pipe_factories_from_source_custom" name = "test_pipe_factories_from_source_custom"