mirror of https://github.com/explosion/spaCy.git
Matcher ID fixes (#4179)
* allow PhraseMatcher to link one match to multiple original patterns
* small fix for defining ent_id in the matcher (anti-ghost prevention)
* cleanup
* formatting
This commit is contained in:
parent
f5d3afb1a3
commit
c417c380e3
|
@ -327,7 +327,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
states[q].length += 1
|
||||
q += 1
|
||||
else:
|
||||
ent_id = get_ent_id(&state.pattern[1])
|
||||
ent_id = get_ent_id(state.pattern)
|
||||
if action == MATCH:
|
||||
matches.push_back(
|
||||
MatchC(pattern_id=ent_id, start=state.start,
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
from libcpp.vector cimport vector
|
||||
|
||||
from ..typedefs cimport hash_t
|
||||
|
||||
ctypedef vector[hash_t] hash_vec
|
|
@ -2,6 +2,7 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
from cymem.cymem cimport Pool
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from preshed.maps cimport PreshMap
|
||||
|
@ -37,6 +38,7 @@ cdef class PhraseMatcher:
|
|||
cdef Vocab vocab
|
||||
cdef Matcher matcher
|
||||
cdef PreshMap phrase_ids
|
||||
cdef vector[hash_vec] ent_id_matrix
|
||||
cdef int max_length
|
||||
cdef attr_id_t attr
|
||||
cdef public object _callbacks
|
||||
|
@ -145,7 +147,23 @@ cdef class PhraseMatcher:
|
|||
lexeme.set_flag(tag, True)
|
||||
phrase_key[i] = lexeme.orth
|
||||
phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
|
||||
self.phrase_ids.set(phrase_hash, <void*>ent_id)
|
||||
|
||||
if phrase_hash in self.phrase_ids:
|
||||
phrase_index = self.phrase_ids[phrase_hash]
|
||||
ent_id_list = self.ent_id_matrix[phrase_index]
|
||||
ent_id_list.append(ent_id)
|
||||
self.ent_id_matrix[phrase_index] = ent_id_list
|
||||
|
||||
else:
|
||||
ent_id_list = hash_vec(1)
|
||||
ent_id_list[0] = ent_id
|
||||
new_index = self.ent_id_matrix.size()
|
||||
if new_index == 0:
|
||||
# PreshMaps can not contain 0 as value, so storing a dummy at 0
|
||||
self.ent_id_matrix.push_back(hash_vec(0))
|
||||
new_index = 1
|
||||
self.ent_id_matrix.push_back(ent_id_list)
|
||||
self.phrase_ids.set(phrase_hash, <void*>new_index)
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
"""Find all sequences matching the supplied patterns on the `Doc`.
|
||||
|
@ -167,9 +185,10 @@ cdef class PhraseMatcher:
|
|||
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
|
||||
match_doc = Doc(self.vocab, words=words)
|
||||
for _, start, end in self.matcher(match_doc):
|
||||
ent_id = self.accept_match(match_doc, start, end)
|
||||
if ent_id is not None:
|
||||
matches.append((ent_id, start, end))
|
||||
ent_ids = self.accept_match(match_doc, start, end)
|
||||
if ent_ids is not None:
|
||||
for ent_id in ent_ids:
|
||||
matches.append((ent_id, start, end))
|
||||
for i, (ent_id, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
|
@ -216,11 +235,11 @@ cdef class PhraseMatcher:
|
|||
for i, j in enumerate(range(start, end)):
|
||||
phrase_key[i] = doc.c[j].lex.orth
|
||||
cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
|
||||
ent_id = <hash_t>self.phrase_ids.get(key)
|
||||
if ent_id == 0:
|
||||
|
||||
ent_index = <hash_t>self.phrase_ids.get(key)
|
||||
if ent_index == 0:
|
||||
return None
|
||||
else:
|
||||
return ent_id
|
||||
return self.ent_id_matrix[ent_index]
|
||||
|
||||
def get_lex_value(self, Doc doc, int i):
|
||||
if self.attr == ORTH:
|
||||
|
|
|
@ -6,7 +6,6 @@ from spacy.matcher import PhraseMatcher
|
|||
from spacy.tokens import Doc
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_issue3972(en_vocab):
|
||||
"""Test that the PhraseMatcher returns duplicates for duplicate match IDs.
|
||||
"""
|
||||
|
@ -15,4 +14,10 @@ def test_issue3972(en_vocab):
|
|||
matcher.add("B", None, Doc(en_vocab, words=["New", "York"]))
|
||||
doc = Doc(en_vocab, words=["I", "live", "in", "New", "York"])
|
||||
matches = matcher(doc)
|
||||
|
||||
assert len(matches) == 2
|
||||
|
||||
# We should have a match for each of the two rules
|
||||
found_ids = [en_vocab.strings[ent_id] for (ent_id, _, _) in matches]
|
||||
assert "A" in found_ids
|
||||
assert "B" in found_ids
|
||||
|
|
Loading…
Reference in New Issue