diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index 4cfab915f..3d99f117a 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -213,28 +213,28 @@ cdef class Matcher:
                 else:
                     yield doc
 
-    def __call__(self, object doc_or_span):
+    def __call__(self, object obj):
         """Find all token sequences matching the supplied pattern.
 
-        doc_or_span (Doc or Span): The document to match over.
+        obj (Doc / Span): The document to match over.
         RETURNS (list): A list of `(key, start, end)` tuples,
             describing the matches. A match tuple describes a span
             `doc[start:end]`. The `label_id` and `key` are both integers.
         """
-        if isinstance(doc_or_span, Doc):
-            doc = doc_or_span
+        if isinstance(obj, Doc):
+            doc = obj
             length = len(doc)
-        elif isinstance(doc_or_span, Span):
-            doc = doc_or_span.doc
-            length = doc_or_span.end - doc_or_span.start
+        elif isinstance(obj, Span):
+            doc = obj.doc
+            length = obj.end - obj.start
         else:
-            raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doc_or_span).__name__))
+            raise ValueError(Errors.E195.format(good="Doc or Span", got=type(obj).__name__))
         if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
                 and not doc.is_tagged:
             raise ValueError(Errors.E155.format())
         if DEP in self._seen_attrs and not doc.is_parsed:
             raise ValueError(Errors.E156.format())
-        matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length,
+        matches = find_matches(&self.patterns[0], self.patterns.size(), obj, length,
                                extensions=self._extensions, predicates=self._extra_predicates)
         for i, (key, start, end) in enumerate(matches):
             on_match = self._callbacks.get(key, None)
@@ -257,7 +257,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
     return matcher
 
 
-cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()):
+cdef find_matches(TokenPatternC** patterns, int n, object obj, int length, extensions=None, predicates=tuple()):
     """Find matches in a doc, with a compiled array of patterns. Matches are
     returned as a list of (id, start, end) tuples.
 
@@ -286,7 +286,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int lengt
     else:
         nr_extra_attr = 0
         extra_attr_values = <attr_t*>mem.alloc(length, sizeof(attr_t))
-    for i, token in enumerate(doc_or_span):
+    for i, token in enumerate(obj):
         for name, index in extensions.items():
             value = token._.get(name)
             if isinstance(value, basestring):
@@ -298,7 +298,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int lengt
         for j in range(n):
            states.push_back(PatternStateC(patterns[j], i, 0))
        transition_states(states, matches, predicate_cache,
-                          doc_or_span[i], extra_attr_values, predicates)
+                          obj[i], extra_attr_values, predicates)
        extra_attr_values += nr_extra_attr
        predicate_cache += len(predicates)
    # Handle matches that end in 0-width patterns
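The renamed argument still accepts either a Doc or a Span, as the isinstance checks in __call__ show. A minimal usage sketch follows, assuming the spaCy v2.x-style Matcher.add(key, on_match, *patterns) signature; the model name and pattern are illustrative and not part of this patch:

    import spacy
    from spacy.matcher import Matcher

    nlp = spacy.load("en_core_web_sm")  # illustrative model name
    matcher = Matcher(nlp.vocab)
    # v2.x-style add: a key, an on_match callback (None here), then a pattern
    matcher.add("HELLO_WORLD", None, [{"LOWER": "hello"}, {"LOWER": "world"}])

    doc = nlp("Hello world! I said hello world again.")
    print(matcher(doc))      # match over the whole Doc
    print(matcher(doc[3:]))  # match over a Span; both paths go through __call__

Either call returns a list of (key, start, end) tuples as described in the docstring; for a Span, only the tokens inside the span are enumerated, since length is computed as obj.end - obj.start.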