mirror of https://github.com/explosion/spaCy.git
Improve checks for sourced components (#7490)
* Improve checks for sourced components * Remove language class checks * Convert python warning to logger warning * Remove unused warning * Fix formatting
This commit is contained in:
parent
05bdbe28bb
commit
1ad646cbcf
|
@ -159,6 +159,8 @@ class Warnings:
|
|||
"http://spacy.io/usage/v3#jupyter-notebook-gpu")
|
||||
W112 = ("The model specified to use for initial vectors ({name}) has no "
|
||||
"vectors. This is almost certainly a mistake.")
|
||||
W113 = ("Sourced component '{name}' may not work as expected: source "
|
||||
"vectors are not identical to current pipeline vectors.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
@ -651,8 +653,8 @@ class Errors:
|
|||
"returned the initialized nlp object instead?")
|
||||
E944 = ("Can't copy pipeline component '{name}' from source '{model}': "
|
||||
"not found in pipeline. Available components: {opts}")
|
||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||
"nlp object, but got: {source}")
|
||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected "
|
||||
"loaded nlp object, but got: {source}")
|
||||
E947 = ("`Matcher.add` received invalid `greedy` argument: expected "
|
||||
"a string value from {expected} but got: '{arg}'")
|
||||
E948 = ("`Matcher.add` received invalid 'patterns' argument: expected "
|
||||
|
|
|
@ -682,9 +682,14 @@ class Language:
|
|||
name (str): Optional alternative name to use in current pipeline.
|
||||
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
||||
"""
|
||||
# TODO: handle errors and mismatches (vectors etc.)
|
||||
if not isinstance(source, self.__class__):
|
||||
# Check source type
|
||||
if not isinstance(source, Language):
|
||||
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
||||
# Check vectors, with faster checks first
|
||||
if self.vocab.vectors.shape != source.vocab.vectors.shape or \
|
||||
self.vocab.vectors.key2row != source.vocab.vectors.key2row or \
|
||||
self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes():
|
||||
util.logger.warning(Warnings.W113.format(name=source_name))
|
||||
if not source_name in source.component_names:
|
||||
raise KeyError(
|
||||
Errors.E944.format(
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import pytest
|
||||
import mock
|
||||
import logging
|
||||
from spacy.language import Language
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.de import German
|
||||
|
@ -402,6 +404,38 @@ def test_pipe_factories_from_source():
|
|||
nlp.add_pipe("custom", source=source_nlp)
|
||||
|
||||
|
||||
def test_pipe_factories_from_source_language_subclass():
|
||||
class CustomEnglishDefaults(English.Defaults):
|
||||
stop_words = set(["custom", "stop"])
|
||||
|
||||
@registry.languages("custom_en")
|
||||
class CustomEnglish(English):
|
||||
lang = "custom_en"
|
||||
Defaults = CustomEnglishDefaults
|
||||
|
||||
source_nlp = English()
|
||||
source_nlp.add_pipe("tagger")
|
||||
|
||||
# custom subclass
|
||||
nlp = CustomEnglish()
|
||||
nlp.add_pipe("tagger", source=source_nlp)
|
||||
assert "tagger" in nlp.pipe_names
|
||||
|
||||
# non-subclass
|
||||
nlp = German()
|
||||
nlp.add_pipe("tagger", source=source_nlp)
|
||||
assert "tagger" in nlp.pipe_names
|
||||
|
||||
# mismatched vectors
|
||||
nlp = English()
|
||||
nlp.vocab.vectors.resize((1, 4))
|
||||
nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4])
|
||||
logger = logging.getLogger("spacy")
|
||||
with mock.patch.object(logger, "warning") as mock_warning:
|
||||
nlp.add_pipe("tagger", source=source_nlp)
|
||||
mock_warning.assert_called()
|
||||
|
||||
|
||||
def test_pipe_factories_from_source_custom():
|
||||
"""Test adding components from a source model with custom components."""
|
||||
name = "test_pipe_factories_from_source_custom"
|
||||
|
|
Loading…
Reference in New Issue