mirror of https://github.com/explosion/spaCy.git
Matcher ID fixes (#4179)
* allow phrasematcher to link one match to multiple original patterns * small fix for defining ent_id in the matcher (anti-ghost prevention) * cleanup * formatting
This commit is contained in:
parent
f5d3afb1a3
commit
c417c380e3
|
@ -327,7 +327,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
||||||
states[q].length += 1
|
states[q].length += 1
|
||||||
q += 1
|
q += 1
|
||||||
else:
|
else:
|
||||||
ent_id = get_ent_id(&state.pattern[1])
|
ent_id = get_ent_id(state.pattern)
|
||||||
if action == MATCH:
|
if action == MATCH:
|
||||||
matches.push_back(
|
matches.push_back(
|
||||||
MatchC(pattern_id=ent_id, start=state.start,
|
MatchC(pattern_id=ent_id, start=state.start,
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
|
from ..typedefs cimport hash_t
|
||||||
|
|
||||||
|
ctypedef vector[hash_t] hash_vec
|
|
@ -2,6 +2,7 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from libcpp.vector cimport vector
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
from preshed.maps cimport PreshMap
|
from preshed.maps cimport PreshMap
|
||||||
|
@ -37,6 +38,7 @@ cdef class PhraseMatcher:
|
||||||
cdef Vocab vocab
|
cdef Vocab vocab
|
||||||
cdef Matcher matcher
|
cdef Matcher matcher
|
||||||
cdef PreshMap phrase_ids
|
cdef PreshMap phrase_ids
|
||||||
|
cdef vector[hash_vec] ent_id_matrix
|
||||||
cdef int max_length
|
cdef int max_length
|
||||||
cdef attr_id_t attr
|
cdef attr_id_t attr
|
||||||
cdef public object _callbacks
|
cdef public object _callbacks
|
||||||
|
@ -145,7 +147,23 @@ cdef class PhraseMatcher:
|
||||||
lexeme.set_flag(tag, True)
|
lexeme.set_flag(tag, True)
|
||||||
phrase_key[i] = lexeme.orth
|
phrase_key[i] = lexeme.orth
|
||||||
phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
|
phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
|
||||||
self.phrase_ids.set(phrase_hash, <void*>ent_id)
|
|
||||||
|
if phrase_hash in self.phrase_ids:
|
||||||
|
phrase_index = self.phrase_ids[phrase_hash]
|
||||||
|
ent_id_list = self.ent_id_matrix[phrase_index]
|
||||||
|
ent_id_list.append(ent_id)
|
||||||
|
self.ent_id_matrix[phrase_index] = ent_id_list
|
||||||
|
|
||||||
|
else:
|
||||||
|
ent_id_list = hash_vec(1)
|
||||||
|
ent_id_list[0] = ent_id
|
||||||
|
new_index = self.ent_id_matrix.size()
|
||||||
|
if new_index == 0:
|
||||||
|
# PreshMaps can not contain 0 as value, so storing a dummy at 0
|
||||||
|
self.ent_id_matrix.push_back(hash_vec(0))
|
||||||
|
new_index = 1
|
||||||
|
self.ent_id_matrix.push_back(ent_id_list)
|
||||||
|
self.phrase_ids.set(phrase_hash, <void*>new_index)
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, Doc doc):
|
||||||
"""Find all sequences matching the supplied patterns on the `Doc`.
|
"""Find all sequences matching the supplied patterns on the `Doc`.
|
||||||
|
@ -167,9 +185,10 @@ cdef class PhraseMatcher:
|
||||||
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
|
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
|
||||||
match_doc = Doc(self.vocab, words=words)
|
match_doc = Doc(self.vocab, words=words)
|
||||||
for _, start, end in self.matcher(match_doc):
|
for _, start, end in self.matcher(match_doc):
|
||||||
ent_id = self.accept_match(match_doc, start, end)
|
ent_ids = self.accept_match(match_doc, start, end)
|
||||||
if ent_id is not None:
|
if ent_ids is not None:
|
||||||
matches.append((ent_id, start, end))
|
for ent_id in ent_ids:
|
||||||
|
matches.append((ent_id, start, end))
|
||||||
for i, (ent_id, start, end) in enumerate(matches):
|
for i, (ent_id, start, end) in enumerate(matches):
|
||||||
on_match = self._callbacks.get(ent_id)
|
on_match = self._callbacks.get(ent_id)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
|
@ -216,11 +235,11 @@ cdef class PhraseMatcher:
|
||||||
for i, j in enumerate(range(start, end)):
|
for i, j in enumerate(range(start, end)):
|
||||||
phrase_key[i] = doc.c[j].lex.orth
|
phrase_key[i] = doc.c[j].lex.orth
|
||||||
cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
|
cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
|
||||||
ent_id = <hash_t>self.phrase_ids.get(key)
|
|
||||||
if ent_id == 0:
|
ent_index = <hash_t>self.phrase_ids.get(key)
|
||||||
|
if ent_index == 0:
|
||||||
return None
|
return None
|
||||||
else:
|
return self.ent_id_matrix[ent_index]
|
||||||
return ent_id
|
|
||||||
|
|
||||||
def get_lex_value(self, Doc doc, int i):
|
def get_lex_value(self, Doc doc, int i):
|
||||||
if self.attr == ORTH:
|
if self.attr == ORTH:
|
||||||
|
|
|
@ -6,7 +6,6 @@ from spacy.matcher import PhraseMatcher
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue3972(en_vocab):
|
def test_issue3972(en_vocab):
|
||||||
"""Test that the PhraseMatcher returns duplicates for duplicate match IDs.
|
"""Test that the PhraseMatcher returns duplicates for duplicate match IDs.
|
||||||
"""
|
"""
|
||||||
|
@ -15,4 +14,10 @@ def test_issue3972(en_vocab):
|
||||||
matcher.add("B", None, Doc(en_vocab, words=["New", "York"]))
|
matcher.add("B", None, Doc(en_vocab, words=["New", "York"]))
|
||||||
doc = Doc(en_vocab, words=["I", "live", "in", "New", "York"])
|
doc = Doc(en_vocab, words=["I", "live", "in", "New", "York"])
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
|
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
|
# We should have a match for each of the two rules
|
||||||
|
found_ids = [en_vocab.strings[ent_id] for (ent_id, _, _) in matches]
|
||||||
|
assert "A" in found_ids
|
||||||
|
assert "B" in found_ids
|
||||||
|
|
Loading…
Reference in New Issue