From 1ad646cbcf0015cb3b944f98bef1b3a9eeb54e9f Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 19 Apr 2021 10:36:32 +0200 Subject: [PATCH] Improve checks for sourced components (#7490) * Improve checks for sourced components * Remove language class checks * Convert python warning to logger warning * Remove unused warning * Fix formatting --- spacy/errors.py | 6 ++-- spacy/language.py | 9 ++++-- spacy/tests/pipeline/test_pipe_factories.py | 34 +++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index e4e331d42..453e98b59 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -159,6 +159,8 @@ class Warnings: "http://spacy.io/usage/v3#jupyter-notebook-gpu") W112 = ("The model specified to use for initial vectors ({name}) has no " "vectors. This is almost certainly a mistake.") + W113 = ("Sourced component '{name}' may not work as expected: source " + "vectors are not identical to current pipeline vectors.") @add_codes @@ -651,8 +653,8 @@ class Errors: "returned the initialized nlp object instead?") E944 = ("Can't copy pipeline component '{name}' from source '{model}': " "not found in pipeline. Available components: {opts}") - E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded " - "nlp object, but got: {source}") + E945 = ("Can't copy pipeline component '{name}' from source. Expected " + "loaded nlp object, but got: {source}") E947 = ("`Matcher.add` received invalid `greedy` argument: expected " "a string value from {expected} but got: '{arg}'") E948 = ("`Matcher.add` received invalid 'patterns' argument: expected " diff --git a/spacy/language.py b/spacy/language.py index 68bd3cd4c..6f6470533 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -682,9 +682,14 @@ class Language: name (str): Optional alternative name to use in current pipeline. RETURNS (Tuple[Callable, str]): The component and its factory name. """ - # TODO: handle errors and mismatches (vectors etc.) - if not isinstance(source, self.__class__): + # Check source type + if not isinstance(source, Language): raise ValueError(Errors.E945.format(name=source_name, source=type(source))) + # Check vectors, with faster checks first + if self.vocab.vectors.shape != source.vocab.vectors.shape or \ + self.vocab.vectors.key2row != source.vocab.vectors.key2row or \ + self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes(): + util.logger.warning(Warnings.W113.format(name=source_name)) if not source_name in source.component_names: raise KeyError( Errors.E944.format( diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py index e1706ffb1..a7071abfd 100644 --- a/spacy/tests/pipeline/test_pipe_factories.py +++ b/spacy/tests/pipeline/test_pipe_factories.py @@ -1,4 +1,6 @@ import pytest +import mock +import logging from spacy.language import Language from spacy.lang.en import English from spacy.lang.de import German @@ -402,6 +404,38 @@ def test_pipe_factories_from_source(): nlp.add_pipe("custom", source=source_nlp) +def test_pipe_factories_from_source_language_subclass(): + class CustomEnglishDefaults(English.Defaults): + stop_words = set(["custom", "stop"]) + + @registry.languages("custom_en") + class CustomEnglish(English): + lang = "custom_en" + Defaults = CustomEnglishDefaults + + source_nlp = English() + source_nlp.add_pipe("tagger") + + # custom subclass + nlp = CustomEnglish() + nlp.add_pipe("tagger", source=source_nlp) + assert "tagger" in nlp.pipe_names + + # non-subclass + nlp = German() + nlp.add_pipe("tagger", source=source_nlp) + assert "tagger" in nlp.pipe_names + + # mismatched vectors + nlp = English() + nlp.vocab.vectors.resize((1, 4)) + nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4]) + logger = logging.getLogger("spacy") + with mock.patch.object(logger, "warning") as mock_warning: + nlp.add_pipe("tagger", source=source_nlp) + mock_warning.assert_called() + + def test_pipe_factories_from_source_custom(): """Test adding components from a source model with custom components.""" name = "test_pipe_factories_from_source_custom"