From 348d1829c7c3834a37e51f9081a7f2214053e8a2 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Tue, 30 Mar 2021 12:26:22 +0200
Subject: [PATCH] Preserve user data for DependencyMatcher on spans (#7528)

* Preserve user data for DependencyMatcher on spans

* Clean underscore in test

* Modify test to use extensions stored in user data
---
 spacy/matcher/dependencymatcher.pyx           |  2 +-
 .../tests/matcher/test_dependency_matcher.py  | 27 ++++++++++++++++++-
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/spacy/matcher/dependencymatcher.pyx b/spacy/matcher/dependencymatcher.pyx
index 4124696b3..0e601281a 100644
--- a/spacy/matcher/dependencymatcher.pyx
+++ b/spacy/matcher/dependencymatcher.pyx
@@ -299,7 +299,7 @@ cdef class DependencyMatcher:
         if isinstance(doclike, Doc):
             doc = doclike
         elif isinstance(doclike, Span):
-            doc = doclike.as_doc()
+            doc = doclike.as_doc(copy_user_data=True)
         else:
             raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
 
diff --git a/spacy/tests/matcher/test_dependency_matcher.py b/spacy/tests/matcher/test_dependency_matcher.py
index a563ddaa2..fb9222aaa 100644
--- a/spacy/tests/matcher/test_dependency_matcher.py
+++ b/spacy/tests/matcher/test_dependency_matcher.py
@@ -4,7 +4,9 @@ import re
 import copy
 from mock import Mock
 from spacy.matcher import DependencyMatcher
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Token
+
+from ..doc.test_underscore import clean_underscore  # noqa: F401
 
 
 @pytest.fixture
@@ -344,3 +346,26 @@ def test_dependency_matcher_long_matches(en_vocab, doc):
     matcher = DependencyMatcher(en_vocab)
     with pytest.raises(ValueError):
         matcher.add("pattern", [pattern])
+
+
+@pytest.mark.usefixtures("clean_underscore")
+def test_dependency_matcher_span_user_data(en_tokenizer):
+    doc = en_tokenizer("a b c d e")
+    for token in doc:
+        token.head = doc[0]
+        token.dep_ = "a"
+    get_is_c = lambda token: token.text in ("c",)
+    Token.set_extension("is_c", default=False)
+    doc[2]._.is_c = True
+    pattern = [
+        {"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}},
+    ]
+    matcher = DependencyMatcher(en_tokenizer.vocab)
+    matcher.add("C", [pattern])
+    doc_matches = matcher(doc)
+    offset = 1
+    span_matches = matcher(doc[offset:])
+    for doc_match, span_match in zip(sorted(doc_matches), sorted(span_matches)):
+        assert doc_match[0] == span_match[0]
+        for doc_t_i, span_t_i in zip(doc_match[1], span_match[1]):
+            assert doc_t_i == span_t_i + offset
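
For illustration, a minimal usage sketch of the behavior this patch enables (assuming spaCy v3 with this change applied). It mirrors the new test: a blank English pipeline, a trivial dependency parse set by hand, and a hypothetical is_c token extension; none of this is part of the patch itself:

    import spacy
    from spacy.matcher import DependencyMatcher
    from spacy.tokens import Token

    # Hypothetical token extension; values set on instances are stored
    # in Doc.user_data, which Span.as_doc() only keeps when
    # copy_user_data=True is passed.
    Token.set_extension("is_c", default=False)

    nlp = spacy.blank("en")
    doc = nlp("a b c d e")
    # DependencyMatcher patterns walk the parse, so give the blank doc
    # a trivial one: every token hangs off the first token.
    for token in doc:
        token.head = doc[0]
        token.dep_ = "dep"
    doc[2]._.is_c = True

    matcher = DependencyMatcher(nlp.vocab)
    matcher.add("C", [[{"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}}]])

    print(matcher(doc))      # one match, on token index 2 ("c")
    print(matcher(doc[1:]))  # same match, at span-relative index 1

Before the one-line change above, the Span branch called as_doc() without copy_user_data, so the temporary Doc started with empty user data and extension-based patterns silently matched nothing on spans.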