From 1ca32d8f9c800eb36e912dc1fa7b173edf7f2c3c Mon Sep 17 00:00:00 2001 From: Paolo Arduin Date: Wed, 15 Apr 2020 13:51:33 +0200 Subject: [PATCH] Matcher support for Span as well as Doc (#5113) * Matcher support for Span, as well as Doc #5056 * Removes an import unused * Signed contributors agreement * Code optimization and better test * Add error message for bad Matcher call argument * Fix merging --- .github/contributors/paoloq.md | 2 +- spacy/errors.py | 1 + spacy/matcher/matcher.pyx | 36 ++++++++++++++----------- spacy/tests/matcher/test_matcher_api.py | 11 +++++++- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/.github/contributors/paoloq.md b/.github/contributors/paoloq.md index 84b28c8ef..0fac70c9a 100644 --- a/.github/contributors/paoloq.md +++ b/.github/contributors/paoloq.md @@ -5,7 +5,7 @@ This spaCy Contributor Agreement (**"SCA"**) is based on the The SCA applies to any contribution that you make to any product or project managed by us (the **"project"**), and sets out the intellectual property rights you grant to us in the contributed materials. The term **"us"** shall mean -[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term +[ExplosionAI GmbH](https://explosion.ai/legal). The term **"you"** shall mean the person or entity identified below. If you agree to be bound by these terms, fill in the information requested diff --git a/spacy/errors.py b/spacy/errors.py index ce26e63a4..b1cdb89ec 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -556,6 +556,7 @@ class Errors(object): "({new_dim}) is not the same as the current vector dimension " "({curr_dim}).") E194 = ("Unable to aligned mismatched text '{text}' and words '{words}'.") + E195 = ("Matcher can be called on {good} only, got {got}.") @add_codes diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 43480b46e..9e0fe2812 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -14,6 +14,7 @@ from ..typedefs cimport attr_t from ..structs cimport TokenC from ..vocab cimport Vocab from ..tokens.doc cimport Doc, get_token_attr +from ..tokens.span cimport Span from ..tokens.token cimport Token from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA @@ -211,22 +212,29 @@ cdef class Matcher: else: yield doc - def __call__(self, Doc doc): + def __call__(self, object doc_or_span): """Find all token sequences matching the supplied pattern. - doc (Doc): The document to match over. + doc_or_span (Doc or Span): The document to match over. RETURNS (list): A list of `(key, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end]`. The `label_id` and `key` are both integers. """ + if isinstance(doc_or_span, Doc): + doc = doc_or_span + length = len(doc) + elif isinstance(doc_or_span, Span): + doc = doc_or_span.doc + length = doc_or_span.end - doc_or_span.start + else: + raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doc_or_span).__name__)) if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \ and not doc.is_tagged: raise ValueError(Errors.E155.format()) if DEP in self._seen_attrs and not doc.is_parsed: raise ValueError(Errors.E156.format()) - matches = find_matches(&self.patterns[0], self.patterns.size(), doc, - extensions=self._extensions, - predicates=self._extra_predicates) + matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length, + extensions=self._extensions, predicates=self._extra_predicates) for i, (key, start, end) in enumerate(matches): on_match = self._callbacks.get(key, None) if on_match is not None: @@ -248,9 +256,7 @@ def unpickle_matcher(vocab, patterns, callbacks): return matcher - -cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, - predicates=tuple()): +cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()): """Find matches in a doc, with a compiled array of patterns. Matches are returned as a list of (id, start, end) tuples. @@ -268,18 +274,18 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, cdef int i, j, nr_extra_attr cdef Pool mem = Pool() output = [] - if doc.length == 0: + if length == 0: # avoid any processing or mem alloc if the document is empty return output if len(predicates) > 0: - predicate_cache = mem.alloc(doc.length * len(predicates), sizeof(char)) + predicate_cache = mem.alloc(length * len(predicates), sizeof(char)) if extensions is not None and len(extensions) >= 1: nr_extra_attr = max(extensions.values()) + 1 - extra_attr_values = mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t)) + extra_attr_values = mem.alloc(length * nr_extra_attr, sizeof(attr_t)) else: nr_extra_attr = 0 - extra_attr_values = mem.alloc(doc.length, sizeof(attr_t)) - for i, token in enumerate(doc): + extra_attr_values = mem.alloc(length, sizeof(attr_t)) + for i, token in enumerate(doc_or_span): for name, index in extensions.items(): value = token._.get(name) if isinstance(value, basestring): @@ -287,11 +293,11 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, extra_attr_values[i * nr_extra_attr + index] = value # Main loop cdef int nr_predicate = len(predicates) - for i in range(doc.length): + for i in range(length): for j in range(n): states.push_back(PatternStateC(patterns[j], i, 0)) transition_states(states, matches, predicate_cache, - doc[i], extra_attr_values, predicates) + doc_or_span[i], extra_attr_values, predicates) extra_attr_values += nr_extra_attr predicate_cache += len(predicates) # Handle matches that end in 0-width patterns diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 2e5e64aac..0295ada82 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -6,7 +6,6 @@ import re from mock import Mock from spacy.matcher import Matcher, DependencyMatcher from spacy.tokens import Doc, Token - from ..doc.test_underscore import clean_underscore # noqa: F401 @@ -470,3 +469,13 @@ def test_matcher_callback(en_vocab): doc = Doc(en_vocab, words=["This", "is", "a", "test", "."]) matches = matcher(doc) mock.assert_called_once_with(matcher, doc, 0, matches) + + +def test_matcher_span(matcher): + text = "JavaScript is good but Java is better" + doc = Doc(matcher.vocab, words=text.split()) + span_js = doc[:3] + span_java = doc[4:] + assert len(matcher(doc)) == 2 + assert len(matcher(span_js)) == 1 + assert len(matcher(span_java)) == 1