Support match alignments (#7321)

* Support match alignments

* change naming from match_alignments to with_alignments, add conditional flow if with_alignments is given, validate with_alignments, add related test case

* remove added errors, utilize bint type, cleanup whitespace

* fix no new line in end of file

* Minor formatting

* Skip alignments processing if as_spans is set

* Add with_alignments to Matcher API docs

* Update website/docs/api/matcher.md

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
broaddeep 2021-04-08 17:10:14 +09:00 committed by GitHub
parent ff84075839
commit ee159b8543
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 321 additions and 28 deletions

106
.github/contributors/broaddeep.md vendored Normal file
View File

@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” on one of the applicable statement below. Please do NOT
mark both statements:
* [x] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect to my
contributions.
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------- |
| Name | Dongjun Park |
| Company name (if applicable) | |
| Title or role (if applicable) | |
| Date | 2021-03-06 |
| GitHub username | broaddeep |
| Website (optional) | |

View File

@ -46,6 +46,12 @@ cdef struct TokenPatternC:
int32_t nr_py int32_t nr_py
quantifier_t quantifier quantifier_t quantifier
hash_t key hash_t key
int32_t token_idx
cdef struct MatchAlignmentC:
int32_t token_idx
int32_t length
cdef struct PatternStateC: cdef struct PatternStateC:

View File

@ -196,7 +196,7 @@ cdef class Matcher:
else: else:
yield doc yield doc
def __call__(self, object doclike, *, as_spans=False, allow_missing=False): def __call__(self, object doclike, *, as_spans=False, allow_missing=False, with_alignments=False):
"""Find all token sequences matching the supplied pattern. """Find all token sequences matching the supplied pattern.
doclike (Doc or Span): The document to match over. doclike (Doc or Span): The document to match over.
@ -204,10 +204,16 @@ cdef class Matcher:
start, end) tuples. start, end) tuples.
allow_missing (bool): Whether to skip checks for missing annotation for allow_missing (bool): Whether to skip checks for missing annotation for
attributes included in patterns. Defaults to False. attributes included in patterns. Defaults to False.
with_alignments (bool): Return match alignment information, which is
`List[int]` with length of matched span. Each entry denotes the
corresponding index of token pattern. If as_spans is set to True,
this setting is ignored.
RETURNS (list): A list of `(match_id, start, end)` tuples, RETURNS (list): A list of `(match_id, start, end)` tuples,
describing the matches. A match tuple describes a span describing the matches. A match tuple describes a span
`doc[start:end]`. The `match_id` is an integer. If as_spans is set `doc[start:end]`. The `match_id` is an integer. If as_spans is set
to True, a list of Span objects is returned. to True, a list of Span objects is returned.
If with_alignments is set to True and as_spans is set to False,
A list of `(match_id, start, end, alignments)` tuples is returned.
""" """
if isinstance(doclike, Doc): if isinstance(doclike, Doc):
doc = doclike doc = doclike
@ -217,6 +223,9 @@ cdef class Matcher:
length = doclike.end - doclike.start length = doclike.end - doclike.start
else: else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__)) raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
# Skip alignments calculations if as_spans is set
if as_spans:
with_alignments = False
cdef Pool tmp_pool = Pool() cdef Pool tmp_pool = Pool()
if not allow_missing: if not allow_missing:
for attr in (TAG, POS, MORPH, LEMMA, DEP): for attr in (TAG, POS, MORPH, LEMMA, DEP):
@ -232,18 +241,20 @@ cdef class Matcher:
error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr)) error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr))
raise ValueError(error_msg) raise ValueError(error_msg)
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length, matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates) extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
final_matches = [] final_matches = []
pairs_by_id = {} pairs_by_id = {}
# For each key, either add all matches, or only the filtered, non-overlapping ones # For each key, either add all matches, or only the filtered,
for (key, start, end) in matches: # non-overlapping ones this `match` can be either (start, end) or
# (start, end, alignments) depending on `with_alignments=` option.
for key, *match in matches:
span_filter = self._filter.get(key) span_filter = self._filter.get(key)
if span_filter is not None: if span_filter is not None:
pairs = pairs_by_id.get(key, []) pairs = pairs_by_id.get(key, [])
pairs.append((start,end)) pairs.append(match)
pairs_by_id[key] = pairs pairs_by_id[key] = pairs
else: else:
final_matches.append((key, start, end)) final_matches.append((key, *match))
matched = <char*>tmp_pool.alloc(length, sizeof(char)) matched = <char*>tmp_pool.alloc(length, sizeof(char))
empty = <char*>tmp_pool.alloc(length, sizeof(char)) empty = <char*>tmp_pool.alloc(length, sizeof(char))
for key, pairs in pairs_by_id.items(): for key, pairs in pairs_by_id.items():
@ -255,14 +266,18 @@ cdef class Matcher:
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
else: else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter)) raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for (start, end) in sorted_pairs: for match in sorted_pairs:
start, end = match[:2]
assert 0 <= start < end # Defend against segfaults assert 0 <= start < end # Defend against segfaults
span_len = end-start span_len = end-start
# If no tokens in the span have matched # If no tokens in the span have matched
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0: if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
final_matches.append((key, start, end)) final_matches.append((key, *match))
# Mark tokens that have matched # Mark tokens that have matched
memset(&matched[start], 1, span_len * sizeof(matched[0])) memset(&matched[start], 1, span_len * sizeof(matched[0]))
if with_alignments:
final_matches_with_alignments = final_matches
final_matches = [(key, start, end) for key, start, end, alignments in final_matches]
# perform the callbacks on the filtered set of results # perform the callbacks on the filtered set of results
for i, (key, start, end) in enumerate(final_matches): for i, (key, start, end) in enumerate(final_matches):
on_match = self._callbacks.get(key, None) on_match = self._callbacks.get(key, None)
@ -270,6 +285,22 @@ cdef class Matcher:
on_match(self, doc, i, final_matches) on_match(self, doc, i, final_matches)
if as_spans: if as_spans:
return [Span(doc, start, end, label=key) for key, start, end in final_matches] return [Span(doc, start, end, label=key) for key, start, end in final_matches]
elif with_alignments:
# convert alignments List[Dict[str, int]] --> List[int]
final_matches = []
# when multiple alignment (belongs to the same length) is found,
# keeps the alignment that has largest token_idx
for key, start, end, alignments in final_matches_with_alignments:
sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
alignments = [0] * (end-start)
for align in sorted_alignments:
if align['length'] >= end-start:
continue
# Since alignments are sorted in order of (length, token_idx)
# this overwrites smaller token_idx when they have same length.
alignments[align['length']] = align['token_idx']
final_matches.append((key, start, end, alignments))
return final_matches
else: else:
return final_matches return final_matches
@ -288,9 +319,9 @@ def unpickle_matcher(vocab, patterns, callbacks):
return matcher return matcher
cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple()): cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple(), bint with_alignments=0):
"""Find matches in a doc, with a compiled array of patterns. Matches are """Find matches in a doc, with a compiled array of patterns. Matches are
returned as a list of (id, start, end) tuples. returned as a list of (id, start, end) tuples or (id, start, end, alignments) tuples (if with_alignments != 0)
To augment the compiled patterns, we optionally also take two Python lists. To augment the compiled patterns, we optionally also take two Python lists.
@ -302,6 +333,8 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
""" """
cdef vector[PatternStateC] states cdef vector[PatternStateC] states
cdef vector[MatchC] matches cdef vector[MatchC] matches
cdef vector[vector[MatchAlignmentC]] align_states
cdef vector[vector[MatchAlignmentC]] align_matches
cdef PatternStateC state cdef PatternStateC state
cdef int i, j, nr_extra_attr cdef int i, j, nr_extra_attr
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -328,12 +361,14 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
for i in range(length): for i in range(length):
for j in range(n): for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0)) states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, predicate_cache, if with_alignments != 0:
doclike[i], extra_attr_values, predicates) align_states.resize(states.size())
transition_states(states, matches, align_states, align_matches, predicate_cache,
doclike[i], extra_attr_values, predicates, with_alignments)
extra_attr_values += nr_extra_attr extra_attr_values += nr_extra_attr
predicate_cache += len(predicates) predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns # Handle matches that end in 0-width patterns
finish_states(matches, states) finish_states(matches, states, align_matches, align_states, with_alignments)
seen = set() seen = set()
for i in range(matches.size()): for i in range(matches.size()):
match = ( match = (
@ -346,16 +381,22 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
# first .?, or the second .? -- it doesn't matter, it's just one match. # first .?, or the second .? -- it doesn't matter, it's just one match.
# Skip 0-length matches. (TODO: fix algorithm) # Skip 0-length matches. (TODO: fix algorithm)
if match not in seen and matches[i].length > 0: if match not in seen and matches[i].length > 0:
output.append(match) if with_alignments != 0:
# since the length of align_matches equals to that of match, we can share same 'i'
output.append(match + (align_matches[i],))
else:
output.append(match)
seen.add(match) seen.add(match)
return output return output
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
int8_t* cached_py_predicates, int8_t* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates) except *: Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
cdef int q = 0 cdef int q = 0
cdef vector[PatternStateC] new_states cdef vector[PatternStateC] new_states
cdef vector[vector[MatchAlignmentC]] align_new_states
cdef int nr_predicate = len(py_predicates) cdef int nr_predicate = len(py_predicates)
for i in range(states.size()): for i in range(states.size()):
if states[i].pattern.nr_py >= 1: if states[i].pattern.nr_py >= 1:
@ -370,23 +411,39 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
# it in the states list, because q doesn't advance. # it in the states list, because q doesn't advance.
state = states[i] state = states[i]
states[q] = state states[q] = state
# Separate from states, performance is guaranteed for users who only need basic options (without alignments).
# `align_states` always corresponds to `states` 1:1.
if with_alignments != 0:
align_state = align_states[i]
align_states[q] = align_state
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND): while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
# Update alignment before the transition of current state
# 'MatchAlignmentC' maps 'original token index of current pattern' to 'current matching length'
if with_alignments != 0:
align_states[q].push_back(MatchAlignmentC(states[q].pattern.token_idx, states[q].length))
if action == RETRY_EXTEND: if action == RETRY_EXTEND:
# This handles the 'extend' # This handles the 'extend'
new_states.push_back( new_states.push_back(
PatternStateC(pattern=states[q].pattern, start=state.start, PatternStateC(pattern=states[q].pattern, start=state.start,
length=state.length+1)) length=state.length+1))
if with_alignments != 0:
align_new_states.push_back(align_states[q])
if action == RETRY_ADVANCE: if action == RETRY_ADVANCE:
# This handles the 'advance' # This handles the 'advance'
new_states.push_back( new_states.push_back(
PatternStateC(pattern=states[q].pattern+1, start=state.start, PatternStateC(pattern=states[q].pattern+1, start=state.start,
length=state.length+1)) length=state.length+1))
if with_alignments != 0:
align_new_states.push_back(align_states[q])
states[q].pattern += 1 states[q].pattern += 1
if states[q].pattern.nr_py != 0: if states[q].pattern.nr_py != 0:
update_predicate_cache(cached_py_predicates, update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates) states[q].pattern, token, py_predicates)
action = get_action(states[q], token.c, extra_attrs, action = get_action(states[q], token.c, extra_attrs,
cached_py_predicates) cached_py_predicates)
# Update alignment before the transition of current state
if with_alignments != 0:
align_states[q].push_back(MatchAlignmentC(states[q].pattern.token_idx, states[q].length))
if action == REJECT: if action == REJECT:
pass pass
elif action == ADVANCE: elif action == ADVANCE:
@ -399,29 +456,50 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1)) length=state.length+1))
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
elif action == MATCH_DOUBLE: elif action == MATCH_DOUBLE:
# push match without last token if length > 0 # push match without last token if length > 0
if state.length > 0: if state.length > 0:
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
length=state.length)) length=state.length))
# MATCH_DOUBLE emits matches twice,
# add one more to align_matches in order to keep 1:1 relationship
if with_alignments != 0:
align_matches.push_back(align_states[q])
# push match with last token # push match with last token
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1)) length=state.length+1))
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
elif action == MATCH_REJECT: elif action == MATCH_REJECT:
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
length=state.length)) length=state.length))
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
elif action == MATCH_EXTEND: elif action == MATCH_EXTEND:
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
length=state.length)) length=state.length))
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
states[q].length += 1 states[q].length += 1
q += 1 q += 1
states.resize(q) states.resize(q)
for i in range(new_states.size()): for i in range(new_states.size()):
states.push_back(new_states[i]) states.push_back(new_states[i])
# `align_states` always corresponds to `states` 1:1
if with_alignments != 0:
align_states.resize(q)
for i in range(align_new_states.size()):
align_states.push_back(align_new_states[i])
cdef int update_predicate_cache(int8_t* cache, cdef int update_predicate_cache(int8_t* cache,
@ -444,15 +522,27 @@ cdef int update_predicate_cache(int8_t* cache,
raise ValueError(Errors.E125.format(value=result)) raise ValueError(Errors.E125.format(value=result))
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *: cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
vector[vector[MatchAlignmentC]]& align_matches,
vector[vector[MatchAlignmentC]]& align_states,
bint with_alignments) except *:
"""Handle states that end in zero-width patterns.""" """Handle states that end in zero-width patterns."""
cdef PatternStateC state cdef PatternStateC state
cdef vector[MatchAlignmentC] align_state
for i in range(states.size()): for i in range(states.size()):
state = states[i] state = states[i]
if with_alignments != 0:
align_state = align_states[i]
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE): while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
# Update alignment before the transition of current state
if with_alignments != 0:
align_state.push_back(MatchAlignmentC(state.pattern.token_idx, state.length))
is_final = get_is_final(state) is_final = get_is_final(state)
if is_final: if is_final:
ent_id = get_ent_id(state.pattern) ent_id = get_ent_id(state.pattern)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_state)
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, length=state.length)) MatchC(pattern_id=ent_id, start=state.start, length=state.length))
break break
@ -607,7 +697,7 @@ cdef int8_t get_quantifier(PatternStateC state) nogil:
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC)) pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
cdef int i, index cdef int i, index
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs): for i, (quantifier, spec, extensions, predicates, token_idx) in enumerate(token_specs):
pattern[i].quantifier = quantifier pattern[i].quantifier = quantifier
# Ensure attrs refers to a null pointer if nr_attr == 0 # Ensure attrs refers to a null pointer if nr_attr == 0
if len(spec) > 0: if len(spec) > 0:
@ -628,6 +718,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
pattern[i].py_predicates[j] = index pattern[i].py_predicates[j] = index
pattern[i].nr_py = len(predicates) pattern[i].nr_py = len(predicates)
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0) pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
pattern[i].token_idx = token_idx
i = len(token_specs) i = len(token_specs)
# Use quantifier to identify final ID pattern node (rather than previous # Use quantifier to identify final ID pattern node (rather than previous
# uninitialized quantifier == 0/ZERO + nr_attr == 0 + non-zero-length attrs) # uninitialized quantifier == 0/ZERO + nr_attr == 0 + non-zero-length attrs)
@ -638,6 +729,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
pattern[i].nr_attr = 1 pattern[i].nr_attr = 1
pattern[i].nr_extra_attr = 0 pattern[i].nr_extra_attr = 0
pattern[i].nr_py = 0 pattern[i].nr_py = 0
pattern[i].token_idx = -1
return pattern return pattern
@ -655,7 +747,7 @@ def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
"""This function interprets the pattern, converting the various bits of """This function interprets the pattern, converting the various bits of
syntactic sugar before we compile it into a struct with init_pattern. syntactic sugar before we compile it into a struct with init_pattern.
We need to split the pattern up into three parts: We need to split the pattern up into four parts:
* Normal attribute/value pairs, which are stored on either the token or lexeme, * Normal attribute/value pairs, which are stored on either the token or lexeme,
can be handled directly. can be handled directly.
* Extension attributes are handled specially, as we need to prefetch the * Extension attributes are handled specially, as we need to prefetch the
@ -664,13 +756,14 @@ def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
functions and store them. So we store these specially as well. functions and store them. So we store these specially as well.
* Extension attributes that have extra predicates are stored within the * Extension attributes that have extra predicates are stored within the
extra_predicates. extra_predicates.
* Token index that this pattern belongs to.
""" """
tokens = [] tokens = []
string_store = vocab.strings string_store = vocab.strings
for spec in token_specs: for token_idx, spec in enumerate(token_specs):
if not spec: if not spec:
# Signifier for 'any token' # Signifier for 'any token'
tokens.append((ONE, [(NULL_ATTR, 0)], [], [])) tokens.append((ONE, [(NULL_ATTR, 0)], [], [], token_idx))
continue continue
if not isinstance(spec, dict): if not isinstance(spec, dict):
raise ValueError(Errors.E154.format()) raise ValueError(Errors.E154.format())
@ -679,7 +772,7 @@ def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
extensions = _get_extensions(spec, string_store, extensions_table) extensions = _get_extensions(spec, string_store, extensions_table)
predicates = _get_extra_predicates(spec, extra_predicates, vocab) predicates = _get_extra_predicates(spec, extra_predicates, vocab)
for op in ops: for op in ops:
tokens.append((op, list(attr_values), list(extensions), list(predicates))) tokens.append((op, list(attr_values), list(extensions), list(predicates), token_idx))
return tokens return tokens

View File

@ -204,3 +204,90 @@ def test_matcher_remove():
# removing again should throw an error # removing again should throw an error
with pytest.raises(ValueError): with pytest.raises(ValueError):
matcher.remove("Rule") matcher.remove("Rule")
def test_matcher_with_alignments_greedy_longest(en_vocab):
cases = [
("aaab", "a* b", [0, 0, 0, 1]),
("baab", "b a* b", [0, 1, 1, 2]),
("aaab", "a a a b", [0, 1, 2, 3]),
("aaab", "a+ b", [0, 0, 0, 1]),
("aaba", "a+ b a+", [0, 0, 1, 2]),
("aabaa", "a+ b a+", [0, 0, 1, 2, 2]),
("aaba", "a+ b a*", [0, 0, 1, 2]),
("aaaa", "a*", [0, 0, 0, 0]),
("baab", "b a* b b*", [0, 1, 1, 2]),
("aabb", "a* b* a*", [0, 0, 1, 1]),
("aaab", "a+ a+ a b", [0, 1, 2, 3]),
("aaab", "a+ a+ a+ b", [0, 1, 2, 3]),
("aaab", "a+ a a b", [0, 1, 2, 3]),
("aaab", "a+ a a", [0, 1, 2]),
("aaab", "a+ a a?", [0, 1, 2]),
("aaaa", "a a a a a?", [0, 1, 2, 3]),
("aaab", "a+ a b", [0, 0, 1, 2]),
("aaab", "a+ a+ b", [0, 0, 1, 2]),
]
for string, pattern_str, result in cases:
matcher = Matcher(en_vocab)
doc = Doc(matcher.vocab, words=list(string))
pattern = []
for part in pattern_str.split():
if part.endswith("+"):
pattern.append({"ORTH": part[0], "OP": "+"})
elif part.endswith("*"):
pattern.append({"ORTH": part[0], "OP": "*"})
elif part.endswith("?"):
pattern.append({"ORTH": part[0], "OP": "?"})
else:
pattern.append({"ORTH": part})
matcher.add("PATTERN", [pattern], greedy="LONGEST")
matches = matcher(doc, with_alignments=True)
n_matches = len(matches)
_, s, e, expected = matches[0]
assert expected == result, (string, pattern_str, s, e, n_matches)
def test_matcher_with_alignments_nongreedy(en_vocab):
cases = [
(0, "aaab", "a* b", [[0, 1], [0, 0, 1], [0, 0, 0, 1], [1]]),
(1, "baab", "b a* b", [[0, 1, 1, 2]]),
(2, "aaab", "a a a b", [[0, 1, 2, 3]]),
(3, "aaab", "a+ b", [[0, 1], [0, 0, 1], [0, 0, 0, 1]]),
(4, "aaba", "a+ b a+", [[0, 1, 2], [0, 0, 1, 2]]),
(5, "aabaa", "a+ b a+", [[0, 1, 2], [0, 0, 1, 2], [0, 0, 1, 2, 2], [0, 1, 2, 2] ]),
(6, "aaba", "a+ b a*", [[0, 1], [0, 0, 1], [0, 0, 1, 2], [0, 1, 2]]),
(7, "aaaa", "a*", [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0]]),
(8, "baab", "b a* b b*", [[0, 1, 1, 2]]),
(9, "aabb", "a* b* a*", [[1], [2], [2, 2], [0, 1], [0, 0, 1], [0, 0, 1, 1], [0, 1, 1], [1, 1]]),
(10, "aaab", "a+ a+ a b", [[0, 1, 2, 3]]),
(11, "aaab", "a+ a+ a+ b", [[0, 1, 2, 3]]),
(12, "aaab", "a+ a a b", [[0, 1, 2, 3]]),
(13, "aaab", "a+ a a", [[0, 1, 2]]),
(14, "aaab", "a+ a a?", [[0, 1], [0, 1, 2]]),
(15, "aaaa", "a a a a a?", [[0, 1, 2, 3]]),
(16, "aaab", "a+ a b", [[0, 1, 2], [0, 0, 1, 2]]),
(17, "aaab", "a+ a+ b", [[0, 1, 2], [0, 0, 1, 2]]),
]
for case_id, string, pattern_str, results in cases:
matcher = Matcher(en_vocab)
doc = Doc(matcher.vocab, words=list(string))
pattern = []
for part in pattern_str.split():
if part.endswith("+"):
pattern.append({"ORTH": part[0], "OP": "+"})
elif part.endswith("*"):
pattern.append({"ORTH": part[0], "OP": "*"})
elif part.endswith("?"):
pattern.append({"ORTH": part[0], "OP": "?"})
else:
pattern.append({"ORTH": part})
matcher.add("PATTERN", [pattern])
matches = matcher(doc, with_alignments=True)
n_matches = len(matches)
for _, s, e, expected in matches:
assert expected in results, (case_id, string, pattern_str, s, e, n_matches)
assert len(expected) == e - s

View File

@ -120,13 +120,14 @@ Find all token sequences matching the supplied patterns on the `Doc` or `Span`.
> matches = matcher(doc) > matches = matcher(doc)
> ``` > ```
| Name | Description | | Name | Description |
| ------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ---------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ | | `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `as_spans` <Tag variant="new">3</Tag> | Instead of tuples, return a list of [`Span`](/api/span) objects of the matches, with the `match_id` assigned as the span label. Defaults to `False`. ~~bool~~ | | `as_spans` <Tag variant="new">3</Tag> | Instead of tuples, return a list of [`Span`](/api/span) objects of the matches, with the `match_id` assigned as the span label. Defaults to `False`. ~~bool~~ |
| `allow_missing` <Tag variant="new">3</Tag> | Whether to skip checks for missing annotation for attributes included in patterns. Defaults to `False`. ~~bool~~ | | `allow_missing` <Tag variant="new">3</Tag> | Whether to skip checks for missing annotation for attributes included in patterns. Defaults to `False`. ~~bool~~ |
| **RETURNS** | A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end`]. The `match_id` is the ID of the added match pattern. If `as_spans` is set to `True`, a list of `Span` objects is returned instead. ~~Union[List[Tuple[int, int, int]], List[Span]]~~ | | `with_alignments` <Tag variant="new">3.1</Tag> | Return match alignment information as part of the match tuple as `List[int]` with the same length as the matched span. Each entry denotes the corresponding index of the token pattern. If `as_spans` is set to `True`, this setting is ignored. Defaults to `False`. ~~bool~~ |
| **RETURNS** | A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end`]. The `match_id` is the ID of the added match pattern. If `as_spans` is set to `True`, a list of `Span` objects is returned instead. ~~Union[List[Tuple[int, int, int]], List[Span]]~~ |
## Matcher.\_\_len\_\_ {#len tag="method" new="2"} ## Matcher.\_\_len\_\_ {#len tag="method" new="2"}