mirror of https://github.com/explosion/spaCy.git
parent
78bb9ff5e0
commit
69fb4bedf2
|
@ -213,28 +213,28 @@ cdef class Matcher:
|
|||
else:
|
||||
yield doc
|
||||
|
||||
def __call__(self, object obj):
|
||||
def __call__(self, object doc_or_span):
|
||||
"""Find all token sequences matching the supplied pattern.
|
||||
|
||||
obj (Doc / Span): The document to match over.
|
||||
doc_or_span (Doc or Span): The document to match over.
|
||||
RETURNS (list): A list of `(key, start, end)` tuples,
|
||||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
"""
|
||||
if isinstance(obj, Doc):
|
||||
doc = obj
|
||||
if isinstance(doc_or_span, Doc):
|
||||
doc = doc_or_span
|
||||
length = len(doc)
|
||||
elif isinstance(obj, Span):
|
||||
doc = obj.doc
|
||||
length = obj.end - obj.start
|
||||
elif isinstance(doc_or_span, Span):
|
||||
doc = doc_or_span.doc
|
||||
length = doc_or_span.end - doc_or_span.start
|
||||
else:
|
||||
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(obj).__name__))
|
||||
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doc_or_span).__name__))
|
||||
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||
and not doc.is_tagged:
|
||||
raise ValueError(Errors.E155.format())
|
||||
if DEP in self._seen_attrs and not doc.is_parsed:
|
||||
raise ValueError(Errors.E156.format())
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), obj, length,
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length,
|
||||
extensions=self._extensions, predicates=self._extra_predicates)
|
||||
for i, (key, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(key, None)
|
||||
|
@ -257,7 +257,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
|
|||
return matcher
|
||||
|
||||
|
||||
cdef find_matches(TokenPatternC** patterns, int n, object obj, int length, extensions=None, predicates=tuple()):
|
||||
cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()):
|
||||
"""Find matches in a doc, with a compiled array of patterns. Matches are
|
||||
returned as a list of (id, start, end) tuples.
|
||||
|
||||
|
@ -286,7 +286,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object obj, int length, exten
|
|||
else:
|
||||
nr_extra_attr = 0
|
||||
extra_attr_values = <attr_t*>mem.alloc(length, sizeof(attr_t))
|
||||
for i, token in enumerate(obj):
|
||||
for i, token in enumerate(doc_or_span):
|
||||
for name, index in extensions.items():
|
||||
value = token._.get(name)
|
||||
if isinstance(value, basestring):
|
||||
|
@ -298,7 +298,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object obj, int length, exten
|
|||
for j in range(n):
|
||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||
transition_states(states, matches, predicate_cache,
|
||||
obj[i], extra_attr_values, predicates)
|
||||
doc_or_span[i], extra_attr_values, predicates)
|
||||
extra_attr_values += nr_extra_attr
|
||||
predicate_cache += len(predicates)
|
||||
# Handle matches that end in 0-width patterns
|
||||
|
|
Loading…
Reference in New Issue