mirror of https://github.com/explosion/spaCy.git
Improve checks for sourced components (#7490)
* Improve checks for sourced components
* Remove language class checks
* Convert Python warning to logger warning
* Remove unused warning
* Fix formatting
This commit is contained in:
parent
05bdbe28bb
commit
1ad646cbcf
|
@ -159,6 +159,8 @@ class Warnings:
|
||||||
"http://spacy.io/usage/v3#jupyter-notebook-gpu")
|
"http://spacy.io/usage/v3#jupyter-notebook-gpu")
|
||||||
W112 = ("The model specified to use for initial vectors ({name}) has no "
|
W112 = ("The model specified to use for initial vectors ({name}) has no "
|
||||||
"vectors. This is almost certainly a mistake.")
|
"vectors. This is almost certainly a mistake.")
|
||||||
|
W113 = ("Sourced component '{name}' may not work as expected: source "
|
||||||
|
"vectors are not identical to current pipeline vectors.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
@ -651,8 +653,8 @@ class Errors:
|
||||||
"returned the initialized nlp object instead?")
|
"returned the initialized nlp object instead?")
|
||||||
E944 = ("Can't copy pipeline component '{name}' from source '{model}': "
|
E944 = ("Can't copy pipeline component '{name}' from source '{model}': "
|
||||||
"not found in pipeline. Available components: {opts}")
|
"not found in pipeline. Available components: {opts}")
|
||||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
E945 = ("Can't copy pipeline component '{name}' from source. Expected "
|
||||||
"nlp object, but got: {source}")
|
"loaded nlp object, but got: {source}")
|
||||||
E947 = ("`Matcher.add` received invalid `greedy` argument: expected "
|
E947 = ("`Matcher.add` received invalid `greedy` argument: expected "
|
||||||
"a string value from {expected} but got: '{arg}'")
|
"a string value from {expected} but got: '{arg}'")
|
||||||
E948 = ("`Matcher.add` received invalid 'patterns' argument: expected "
|
E948 = ("`Matcher.add` received invalid 'patterns' argument: expected "
|
||||||
|
|
|
@ -682,9 +682,14 @@ class Language:
|
||||||
name (str): Optional alternative name to use in current pipeline.
|
name (str): Optional alternative name to use in current pipeline.
|
||||||
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
||||||
"""
|
"""
|
||||||
# TODO: handle errors and mismatches (vectors etc.)
|
# Check source type
|
||||||
if not isinstance(source, self.__class__):
|
if not isinstance(source, Language):
|
||||||
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
||||||
|
# Check vectors, with faster checks first
|
||||||
|
if self.vocab.vectors.shape != source.vocab.vectors.shape or \
|
||||||
|
self.vocab.vectors.key2row != source.vocab.vectors.key2row or \
|
||||||
|
self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes():
|
||||||
|
util.logger.warning(Warnings.W113.format(name=source_name))
|
||||||
if not source_name in source.component_names:
|
if not source_name in source.component_names:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
Errors.E944.format(
|
Errors.E944.format(
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
import mock
|
||||||
|
import logging
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.de import German
|
from spacy.lang.de import German
|
||||||
|
@ -402,6 +404,38 @@ def test_pipe_factories_from_source():
|
||||||
nlp.add_pipe("custom", source=source_nlp)
|
nlp.add_pipe("custom", source=source_nlp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipe_factories_from_source_language_subclass():
|
||||||
|
class CustomEnglishDefaults(English.Defaults):
|
||||||
|
stop_words = set(["custom", "stop"])
|
||||||
|
|
||||||
|
@registry.languages("custom_en")
|
||||||
|
class CustomEnglish(English):
|
||||||
|
lang = "custom_en"
|
||||||
|
Defaults = CustomEnglishDefaults
|
||||||
|
|
||||||
|
source_nlp = English()
|
||||||
|
source_nlp.add_pipe("tagger")
|
||||||
|
|
||||||
|
# custom subclass
|
||||||
|
nlp = CustomEnglish()
|
||||||
|
nlp.add_pipe("tagger", source=source_nlp)
|
||||||
|
assert "tagger" in nlp.pipe_names
|
||||||
|
|
||||||
|
# non-subclass
|
||||||
|
nlp = German()
|
||||||
|
nlp.add_pipe("tagger", source=source_nlp)
|
||||||
|
assert "tagger" in nlp.pipe_names
|
||||||
|
|
||||||
|
# mismatched vectors
|
||||||
|
nlp = English()
|
||||||
|
nlp.vocab.vectors.resize((1, 4))
|
||||||
|
nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4])
|
||||||
|
logger = logging.getLogger("spacy")
|
||||||
|
with mock.patch.object(logger, "warning") as mock_warning:
|
||||||
|
nlp.add_pipe("tagger", source=source_nlp)
|
||||||
|
mock_warning.assert_called()
|
||||||
|
|
||||||
|
|
||||||
def test_pipe_factories_from_source_custom():
|
def test_pipe_factories_from_source_custom():
|
||||||
"""Test adding components from a source model with custom components."""
|
"""Test adding components from a source model with custom components."""
|
||||||
name = "test_pipe_factories_from_source_custom"
|
name = "test_pipe_factories_from_source_custom"
|
||||||
|
|
Loading…
Reference in New Issue