diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 75881848a..68428d843 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -84,12 +84,13 @@ cdef class PhraseMatcher: return (unpickle_matcher, data, None, None) def remove(self, key): - """Remove a match-rule from the matcher by match ID. + """Remove a rule from the matcher by match ID. A KeyError is raised if + the key does not exist. key (unicode): The match ID. """ if key not in self._keywords: - return + raise KeyError(key) cdef MapStruct* current_node cdef MapStruct* terminal_map cdef MapStruct* node_pointer @@ -97,13 +98,16 @@ cdef class PhraseMatcher: cdef key_t terminal_key cdef void* value cdef int c_i = 0 + cdef vector[MapStruct*] path_nodes + cdef vector[key_t] path_keys + cdef key_t key_to_remove for keyword in self._keywords[key]: current_node = self.c_map - token_trie_list = [] for token in keyword: result = map_get(current_node, token) if result: - token_trie_list.append((token, current_node)) + path_nodes.push_back(current_node) + path_keys.push_back(token) current_node = result else: # if token is not found, break out of the loop @@ -113,27 +117,25 @@ cdef class PhraseMatcher: # keywords with them result = map_get(current_node, self._terminal_hash) if current_node != NULL and result: - # if this is the only remaining key, remove unnecessary paths terminal_map = result terminal_keys = [] c_i = 0 while map_iter(terminal_map, &c_i, &terminal_key, &value): terminal_keys.append(self.vocab.strings[terminal_key]) - # TODO: not working, fix remove for unused paths/maps - if False and terminal_keys == [key]: - # we found a complete match for input keyword - token_trie_list.append((self.vocab.strings[key], terminal_map)) - token_trie_list.reverse() - for key_to_remove, py_node_pointer in token_trie_list: - node_pointer = py_node_pointer + # if this is the only remaining key, remove unnecessary paths + if terminal_keys == [key]: + while not path_nodes.empty(): + node_pointer = path_nodes.back() + path_nodes.pop_back() + key_to_remove = path_keys.back() + path_keys.pop_back() result = map_get(node_pointer, key_to_remove) if node_pointer.filled == 1: map_clear(node_pointer, key_to_remove) self.mem.free(result) - pass else: # more than one key means more than 1 path, - # delete not required path and keep the other + # delete not required path and keep the others map_clear(node_pointer, key_to_remove) self.mem.free(result) break diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index 7d65d0007..486cbb984 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -84,7 +84,8 @@ def test_phrase_matcher_remove(en_vocab): assert "TEST2" not in matcher assert "TEST3" not in matcher assert len(matcher(doc)) == 0 - matcher.remove("TEST3") + with pytest.raises(KeyError): + matcher.remove("TEST3") assert "TEST1" not in matcher assert "TEST2" not in matcher assert "TEST3" not in matcher