diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 260e72e40..c698c8024 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -327,7 +327,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match states[q].length += 1 q += 1 else: - ent_id = get_ent_id(&state.pattern[1]) + ent_id = get_ent_id(state.pattern) if action == MATCH: matches.push_back( MatchC(pattern_id=ent_id, start=state.start, diff --git a/spacy/matcher/phrasematcher.pxd b/spacy/matcher/phrasematcher.pxd new file mode 100644 index 000000000..3aba1686f --- /dev/null +++ b/spacy/matcher/phrasematcher.pxd @@ -0,0 +1,5 @@ +from libcpp.vector cimport vector + +from ..typedefs cimport hash_t + +ctypedef vector[hash_t] hash_vec diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 3a8bec2df..9e8801cc1 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -2,6 +2,7 @@ # cython: profile=True from __future__ import unicode_literals +from libcpp.vector cimport vector from cymem.cymem cimport Pool from murmurhash.mrmr cimport hash64 from preshed.maps cimport PreshMap @@ -37,6 +38,7 @@ cdef class PhraseMatcher: cdef Vocab vocab cdef Matcher matcher cdef PreshMap phrase_ids + cdef vector[hash_vec] ent_id_matrix cdef int max_length cdef attr_id_t attr cdef public object _callbacks @@ -145,7 +147,23 @@ cdef class PhraseMatcher: lexeme.set_flag(tag, True) phrase_key[i] = lexeme.orth phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0) - self.phrase_ids.set(phrase_hash, ent_id) + + if phrase_hash in self.phrase_ids: + phrase_index = self.phrase_ids[phrase_hash] + ent_id_list = self.ent_id_matrix[phrase_index] + ent_id_list.append(ent_id) + self.ent_id_matrix[phrase_index] = ent_id_list + + else: + ent_id_list = hash_vec(1) + ent_id_list[0] = ent_id + new_index = self.ent_id_matrix.size() + if new_index == 0: + # PreshMaps can not contain 0 as value, so storing a dummy at 0 + self.ent_id_matrix.push_back(hash_vec(0)) + new_index = 1 + self.ent_id_matrix.push_back(ent_id_list) + self.phrase_ids.set(phrase_hash, new_index) def __call__(self, Doc doc): """Find all sequences matching the supplied patterns on the `Doc`. @@ -167,9 +185,10 @@ cdef class PhraseMatcher: words = [self.get_lex_value(doc, i) for i in range(len(doc))] match_doc = Doc(self.vocab, words=words) for _, start, end in self.matcher(match_doc): - ent_id = self.accept_match(match_doc, start, end) - if ent_id is not None: - matches.append((ent_id, start, end)) + ent_ids = self.accept_match(match_doc, start, end) + if ent_ids is not None: + for ent_id in ent_ids: + matches.append((ent_id, start, end)) for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(ent_id) if on_match is not None: @@ -216,11 +235,11 @@ cdef class PhraseMatcher: for i, j in enumerate(range(start, end)): phrase_key[i] = doc.c[j].lex.orth cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0) - ent_id = self.phrase_ids.get(key) - if ent_id == 0: + + ent_index = self.phrase_ids.get(key) + if ent_index == 0: return None - else: - return ent_id + return self.ent_id_matrix[ent_index] def get_lex_value(self, Doc doc, int i): if self.attr == ORTH: diff --git a/spacy/tests/regression/test_issue3972.py b/spacy/tests/regression/test_issue3972.py index e82dff269..1bc762699 100644 --- a/spacy/tests/regression/test_issue3972.py +++ b/spacy/tests/regression/test_issue3972.py @@ -6,7 +6,6 @@ from spacy.matcher import PhraseMatcher from spacy.tokens import Doc -@pytest.mark.xfail def test_issue3972(en_vocab): """Test that the PhraseMatcher returns duplicates for duplicate match IDs. """ @@ -15,4 +14,10 @@ def test_issue3972(en_vocab): matcher.add("B", None, Doc(en_vocab, words=["New", "York"])) doc = Doc(en_vocab, words=["I", "live", "in", "New", "York"]) matches = matcher(doc) + assert len(matches) == 2 + + # We should have a match for each of the two rules + found_ids = [en_vocab.strings[ent_id] for (ent_id, _, _) in matches] + assert "A" in found_ids + assert "B" in found_ids