From 0d9740e8266cfbdc74d15c708427f759f1f02588 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 19 Sep 2019 16:36:12 +0200 Subject: [PATCH] Replace PhraseMatcher with Aho-Corasick Replace PhraseMatcher with the Aho-Corasick algorithm over numpy arrays of the hash values for the relevant attribute. The implementation is based on FlashText. The speed should be similar to the previous PhraseMatcher. It is now possible to easily remove match IDs and matches don't go missing with large keyword lists / vocabularies. Fixes #4308. --- spacy/matcher/phrasematcher.pxd | 5 - spacy/matcher/phrasematcher.pyx | 267 +++++++++++---------- spacy/tests/matcher/test_phrase_matcher.py | 52 ++++ 3 files changed, 189 insertions(+), 135 deletions(-) diff --git a/spacy/matcher/phrasematcher.pxd b/spacy/matcher/phrasematcher.pxd index 3aba1686f..e69de29bb 100644 --- a/spacy/matcher/phrasematcher.pxd +++ b/spacy/matcher/phrasematcher.pxd @@ -1,5 +0,0 @@ -from libcpp.vector cimport vector - -from ..typedefs cimport hash_t - -ctypedef vector[hash_t] hash_vec diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 9e8801cc1..76a66c506 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -2,28 +2,14 @@ # cython: profile=True from __future__ import unicode_literals -from libcpp.vector cimport vector -from cymem.cymem cimport Pool -from murmurhash.mrmr cimport hash64 -from preshed.maps cimport PreshMap +import numpy as np -from .matcher cimport Matcher from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t from ..vocab cimport Vocab from ..tokens.doc cimport Doc, get_token_attr -from ..typedefs cimport attr_t, hash_t from ._schemas import TOKEN_PATTERN_SCHEMA from ..errors import Errors, Warnings, deprecation_warning, user_warning -from ..attrs import FLAG61 as U_ENT -from ..attrs import FLAG60 as B2_ENT -from ..attrs import FLAG59 as B3_ENT -from ..attrs import FLAG58 as B4_ENT -from ..attrs import FLAG43 as L2_ENT -from ..attrs import FLAG42 as L3_ENT -from ..attrs import FLAG41 as L4_ENT -from ..attrs import FLAG42 as I3_ENT -from ..attrs import FLAG41 as I4_ENT cdef class PhraseMatcher: @@ -33,18 +19,18 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher + + Adapted from FlashText: https://github.com/vi3k6i5/flashtext + MIT License (see `LICENSE`) + Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com) """ - cdef Pool mem cdef Vocab vocab - cdef Matcher matcher - cdef PreshMap phrase_ids - cdef vector[hash_vec] ent_id_matrix - cdef int max_length + cdef unicode _terminal + cdef object keyword_trie_dict cdef attr_id_t attr - cdef public object _callbacks - cdef public object _patterns - cdef public object _docs - cdef public object _validate + cdef object _callbacks + cdef object _keywords + cdef bint _validate def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False): """Initialize the PhraseMatcher. @@ -58,10 +44,13 @@ cdef class PhraseMatcher: """ if max_length != 0: deprecation_warning(Warnings.W010) - self.mem = Pool() - self.max_length = max_length self.vocab = vocab - self.matcher = Matcher(self.vocab, validate=False) + self._terminal = '_terminal_' + self.keyword_trie_dict = dict() + self._callbacks = {} + self._keywords = {} + self._validate = validate + if isinstance(attr, long): self.attr = attr else: @@ -71,28 +60,15 @@ cdef class PhraseMatcher: if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]: raise ValueError(Errors.E152.format(attr=attr)) self.attr = self.vocab.strings[attr] - self.phrase_ids = PreshMap() - abstract_patterns = [ - [{U_ENT: True}], - [{B2_ENT: True}, {L2_ENT: True}], - [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}], - [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}], - ] - self.matcher.add("Candidate", None, *abstract_patterns) - self._callbacks = {} - self._docs = {} - self._validate = validate def __len__(self): - """Get the number of rules added to the matcher. Note that this only - returns the number of rules (identical with the number of IDs), not the - number of individual patterns. + """Get the number of match IDs added to the matcher. RETURNS (int): The number of rules. DOCS: https://spacy.io/api/phrasematcher#len """ - return len(self._docs) + return len(self._callbacks) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. @@ -102,12 +78,48 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#contains """ - cdef hash_t ent_id = self.matcher._normalize_key(key) - return ent_id in self._callbacks + return key in self._callbacks - def __reduce__(self): - data = (self.vocab, self._docs, self._callbacks) - return (unpickle_matcher, data, None, None) + def remove(self, key): + """Remove a match-rule from the matcher by match ID. + + key (unicode): The match ID. + """ + if key not in self._keywords: + return + for keyword in self._keywords[key]: + current_dict = self.keyword_trie_dict + token_trie_list = [] + for tokens in keyword: + if tokens in current_dict: + token_trie_list.append((tokens, current_dict)) + current_dict = current_dict[tokens] + else: + # if token is not found, break out of the loop + current_dict = None + break + # remove the tokens from trie dict if there are no other + # keywords with them + if current_dict and self._terminal in current_dict: + # if this is the only remaining key, remove unnecessary paths + if current_dict[self._terminal] == [key]: + # we found a complete match for input keyword + token_trie_list.append((self._terminal, current_dict)) + token_trie_list.reverse() + for key_to_remove, dict_pointer in token_trie_list: + if len(dict_pointer.keys()) == 1: + dict_pointer.pop(key_to_remove) + else: + # more than one key means more than 1 path, + # delete not required path and keep the other + dict_pointer.pop(key_to_remove) + break + # otherwise simply remove the key + else: + current_dict[self._terminal].remove(key) + + del self._keywords[key] + del self._callbacks[key] def add(self, key, on_match, *docs): """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID @@ -119,17 +131,13 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#add """ - cdef Doc doc - cdef hash_t ent_id = self.matcher._normalize_key(key) - self._callbacks[ent_id] = on_match - self._docs[ent_id] = docs - cdef int length - cdef int i - cdef hash_t phrase_hash - cdef Pool mem = Pool() + + _ = self.vocab[key] + self._callbacks[key] = on_match + self._keywords.setdefault(key, []) + for doc in docs: - length = doc.length - if length == 0: + if len(doc) == 0: continue if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged: raise ValueError(Errors.E155.format()) @@ -139,33 +147,18 @@ cdef class PhraseMatcher: and self.attr not in (DEP, POS, TAG, LEMMA): string_attr = self.vocab.strings[self.attr] user_warning(Warnings.W012.format(key=key, attr=string_attr)) - tags = get_biluo(length) - phrase_key = mem.alloc(length, sizeof(attr_t)) - for i, tag in enumerate(tags): - attr_value = self.get_lex_value(doc, i) - lexeme = self.vocab[attr_value] - lexeme.set_flag(tag, True) - phrase_key[i] = lexeme.orth - phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0) + keyword = self._convert_to_array(doc) + # keep track of keywords per key to make remove easier + # (would use a set, but can't hash numpy arrays) + if keyword not in self._keywords[key]: + self._keywords[key].append(keyword) + current_dict = self.keyword_trie_dict + for token in keyword: + current_dict = current_dict.setdefault(token, {}) + current_dict.setdefault(self._terminal, set()) + current_dict[self._terminal].add(key) - if phrase_hash in self.phrase_ids: - phrase_index = self.phrase_ids[phrase_hash] - ent_id_list = self.ent_id_matrix[phrase_index] - ent_id_list.append(ent_id) - self.ent_id_matrix[phrase_index] = ent_id_list - - else: - ent_id_list = hash_vec(1) - ent_id_list[0] = ent_id - new_index = self.ent_id_matrix.size() - if new_index == 0: - # PreshMaps can not contain 0 as value, so storing a dummy at 0 - self.ent_id_matrix.push_back(hash_vec(0)) - new_index = 1 - self.ent_id_matrix.push_back(ent_id_list) - self.phrase_ids.set(phrase_hash, new_index) - - def __call__(self, Doc doc): + def __call__(self, doc): """Find all sequences matching the supplied patterns on the `Doc`. doc (Doc): The document to match over. @@ -175,20 +168,62 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#call """ + doc_array = self._convert_to_array(doc) matches = [] - if self.attr == ORTH: - match_doc = doc - else: - # If we're not matching on the ORTH, match_doc will be a Doc whose - # token.orth values are the attribute values we're matching on, - # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc]) - words = [self.get_lex_value(doc, i) for i in range(len(doc))] - match_doc = Doc(self.vocab, words=words) - for _, start, end in self.matcher(match_doc): - ent_ids = self.accept_match(match_doc, start, end) - if ent_ids is not None: - for ent_id in ent_ids: - matches.append((ent_id, start, end)) + if doc_array is None or len(doc_array) == 0: + # if doc_array is empty or None just return empty list + return matches + current_dict = self.keyword_trie_dict + start = 0 + reset_current_dict = False + idx = 0 + doc_array_len = len(doc_array) + while idx < doc_array_len: + token = doc_array[idx] + # if end is present in current_dict + if self._terminal in current_dict or token in current_dict: + if self._terminal in current_dict: + ent_id = current_dict[self._terminal] + matches.append((self.vocab.strings[ent_id], start, idx)) + + # look for longer sequences from this position + if token in current_dict: + current_dict_continued = current_dict[token] + + idy = idx + 1 + while idy < doc_array_len: + inner_token = doc_array[idy] + if self._terminal in current_dict_continued: + ent_ids = current_dict_continued[self._terminal] + for ent_id in ent_ids: + matches.append((self.vocab.strings[ent_id], start, idy)) + if inner_token in current_dict_continued: + current_dict_continued = current_dict_continued[inner_token] + else: + break + idy += 1 + else: + # end of doc_array reached + if self._terminal in current_dict_continued: + ent_ids = current_dict_continued[self._terminal] + for ent_id in ent_ids: + matches.append((self.vocab.strings[ent_id], start, idy)) + current_dict = self.keyword_trie_dict + reset_current_dict = True + else: + # we reset current_dict + current_dict = self.keyword_trie_dict + reset_current_dict = True + # if we are end of doc_array and have a sequence discovered + if idx + 1 >= doc_array_len: + if self._terminal in current_dict: + ent_ids = current_dict[self._terminal] + for ent_id in ent_ids: + matches.append((self.vocab.strings[ent_id], start, doc_array_len)) + idx += 1 + if reset_current_dict: + reset_current_dict = False + start = idx for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(ent_id) if on_match is not None: @@ -228,19 +263,6 @@ cdef class PhraseMatcher: else: yield doc - def accept_match(self, Doc doc, int start, int end): - cdef int i, j - cdef Pool mem = Pool() - phrase_key = mem.alloc(end-start, sizeof(attr_t)) - for i, j in enumerate(range(start, end)): - phrase_key[i] = doc.c[j].lex.orth - cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0) - - ent_index = self.phrase_ids.get(key) - if ent_index == 0: - return None - return self.ent_id_matrix[ent_index] - def get_lex_value(self, Doc doc, int i): if self.attr == ORTH: # Return the regular orth value of the lexeme @@ -256,25 +278,10 @@ cdef class PhraseMatcher: # Concatenate the attr name and value to not pollute lexeme space # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise # create false positive matches - return "matcher:{}-{}".format(string_attr_name, string_attr_value) + matcher_attr_string = "matcher:{}-{}".format(string_attr_name, string_attr_value) + # Add new string to vocab + _ = self.vocab[matcher_attr_string] + return self.vocab.strings[matcher_attr_string] - -def get_biluo(length): - if length == 0: - raise ValueError(Errors.E127) - elif length == 1: - return [U_ENT] - elif length == 2: - return [B2_ENT, L2_ENT] - elif length == 3: - return [B3_ENT, I3_ENT, L3_ENT] - else: - return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] - - -def unpickle_matcher(vocab, docs, callbacks): - matcher = PhraseMatcher(vocab) - for key, specs in docs.items(): - callback = callbacks.get(key, None) - matcher.add(key, callback, *specs) - return matcher + def _convert_to_array(self, Doc doc): + return np.array([self.get_lex_value(doc, i) for i in range(len(doc))], dtype=np.uint64) diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index b82f9a058..a9f5ac990 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -31,6 +31,58 @@ def test_phrase_matcher_contains(en_vocab): assert "TEST2" not in matcher +def test_phrase_matcher_repeated_add(en_vocab): + matcher = PhraseMatcher(en_vocab) + # match ID only gets added once + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST" in matcher + assert "TEST2" not in matcher + assert len(matcher(doc)) == 1 + + +def test_phrase_matcher_remove(en_vocab): + matcher = PhraseMatcher(en_vocab) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST" in matcher + assert "TEST2" not in matcher + assert len(matcher(doc)) == 1 + matcher.remove("TEST") + assert "TEST" not in matcher + assert "TEST2" not in matcher + assert len(matcher(doc)) == 0 + matcher.remove("TEST2") + assert "TEST" not in matcher + assert "TEST2" not in matcher + assert len(matcher(doc)) == 0 + + +def test_phrase_matcher_overlapping_with_remove(en_vocab): + matcher = PhraseMatcher(en_vocab) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + # TEST2 is added alongside TEST + matcher.add("TEST2", None, Doc(en_vocab, words=["like"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST" in matcher + assert len(matcher) == 2 + assert len(matcher(doc)) == 2 + # removing TEST does not remove the entry for TEST2 + matcher.remove("TEST") + assert "TEST" not in matcher + assert len(matcher) == 1 + assert len(matcher(doc)) == 1 + assert matcher(doc)[0][0] == en_vocab.strings["TEST2"] + # removing TEST2 removes all + matcher.remove("TEST2") + assert "TEST2" not in matcher + assert len(matcher) == 0 + assert len(matcher(doc)) == 0 + + def test_phrase_matcher_string_attrs(en_vocab): words1 = ["I", "like", "cats"] pos1 = ["PRON", "VERB", "NOUN"]