Preserve user data for DependencyMatcher on spans (#7528)

* Preserve user data for DependencyMatcher on spans

* Clean underscore in test

* Modify test to use extensions stored in user data
This commit is contained in:
Adriane Boyd 2021-03-30 12:26:22 +02:00 committed by GitHub
parent 921feee092
commit 348d1829c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 2 deletions

View File

@ -299,7 +299,7 @@ cdef class DependencyMatcher:
if isinstance(doclike, Doc):
doc = doclike
elif isinstance(doclike, Span):
doc = doclike.as_doc()
doc = doclike.as_doc(copy_user_data=True)
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))

View File

@ -4,7 +4,9 @@ import re
import copy
from mock import Mock
from spacy.matcher import DependencyMatcher
from spacy.tokens import Doc
from spacy.tokens import Doc, Token
from ..doc.test_underscore import clean_underscore # noqa: F401
@pytest.fixture
@ -344,3 +346,26 @@ def test_dependency_matcher_long_matches(en_vocab, doc):
matcher = DependencyMatcher(en_vocab)
with pytest.raises(ValueError):
matcher.add("pattern", [pattern])
@pytest.mark.usefixtures("clean_underscore")
def test_dependency_matcher_span_user_data(en_tokenizer):
doc = en_tokenizer("a b c d e")
for token in doc:
token.head = doc[0]
token.dep_ = "a"
get_is_c = lambda token: token.text in ("c",)
Token.set_extension("is_c", default=False)
doc[2]._.is_c = True
pattern = [
{"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}},
]
matcher = DependencyMatcher(en_tokenizer.vocab)
matcher.add("C", [pattern])
doc_matches = matcher(doc)
offset = 1
span_matches = matcher(doc[offset:])
for doc_match, span_match in zip(sorted(doc_matches), sorted(span_matches)):
assert doc_match[0] == span_match[0]
for doc_t_i, span_t_i in zip(doc_match[1], span_match[1]):
assert doc_t_i == span_t_i + offset