mirror of https://github.com/explosion/spaCy.git
Switch from numpy array to Token.get_struct_attr
Access token attributes directly in Doc instead of making a copy of the relevant values in a numpy array. Add unsatisfactory warning for hash collision with reserved terminal hash key. (Ideally it would change the reserved terminal hash and redo the whole trie, but for now, I'm hoping there won't be collisions.)
This commit is contained in:
parent
d995a7849e
commit
3c6f1d7e3a
|
@ -86,6 +86,8 @@ class Warnings(object):
|
||||||
"previously loaded vectors. See Issue #3853.")
|
"previously loaded vectors. See Issue #3853.")
|
||||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||||
"loaded. (Shape: {shape})")
|
"loaded. (Shape: {shape})")
|
||||||
|
W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
|
||||||
|
"incorrect. Modify PhraseMatcher._terminal_hash to fix.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -11,10 +11,10 @@ from cymem.cymem cimport Pool
|
||||||
from preshed.maps cimport MapStruct, map_init, map_set, map_get, map_clear
|
from preshed.maps cimport MapStruct, map_init, map_set, map_get, map_clear
|
||||||
from preshed.maps cimport map_iter, key_t
|
from preshed.maps cimport map_iter, key_t
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
|
from ..structs cimport TokenC
|
||||||
|
from ..tokens.token cimport Token
|
||||||
from ..tokens.doc cimport Doc, get_token_attr
|
from ..tokens.doc cimport Doc, get_token_attr
|
||||||
|
|
||||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||||
|
@ -42,9 +42,9 @@ cdef class PhraseMatcher:
|
||||||
|
|
||||||
cdef MapStruct* c_map
|
cdef MapStruct* c_map
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef key_t _terminal_node
|
cdef key_t _terminal_hash
|
||||||
|
|
||||||
cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil
|
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
|
||||||
|
|
||||||
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
|
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
|
||||||
"""Initialize the PhraseMatcher.
|
"""Initialize the PhraseMatcher.
|
||||||
|
@ -66,7 +66,7 @@ cdef class PhraseMatcher:
|
||||||
|
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
self._terminal_node = 1 # or random: np.random.randint(0, high=np.iinfo(np.uint64).max, dtype=np.uint64)
|
self._terminal_hash = 826361138722620965
|
||||||
map_init(self.mem, self.c_map, 8)
|
map_init(self.mem, self.c_map, 8)
|
||||||
|
|
||||||
if isinstance(attr, long):
|
if isinstance(attr, long):
|
||||||
|
@ -130,7 +130,7 @@ cdef class PhraseMatcher:
|
||||||
break
|
break
|
||||||
# remove the tokens from trie node if there are no other
|
# remove the tokens from trie node if there are no other
|
||||||
# keywords with them
|
# keywords with them
|
||||||
result = map_get(current_node, self._terminal_node)
|
result = map_get(current_node, self._terminal_hash)
|
||||||
if current_node != NULL and result:
|
if current_node != NULL and result:
|
||||||
# if this is the only remaining key, remove unnecessary paths
|
# if this is the only remaining key, remove unnecessary paths
|
||||||
terminal_map = <MapStruct*>result
|
terminal_map = <MapStruct*>result
|
||||||
|
@ -158,7 +158,7 @@ cdef class PhraseMatcher:
|
||||||
break
|
break
|
||||||
# otherwise simply remove the key
|
# otherwise simply remove the key
|
||||||
else:
|
else:
|
||||||
result = map_get(current_node, self._terminal_node)
|
result = map_get(current_node, self._terminal_hash)
|
||||||
if result:
|
if result:
|
||||||
map_clear(<MapStruct*>result, self.vocab.strings[key])
|
map_clear(<MapStruct*>result, self.vocab.strings[key])
|
||||||
|
|
||||||
|
@ -205,6 +205,9 @@ cdef class PhraseMatcher:
|
||||||
|
|
||||||
current_node = self.c_map
|
current_node = self.c_map
|
||||||
for token in keyword:
|
for token in keyword:
|
||||||
|
if token == self._terminal_hash:
|
||||||
|
user_warning(Warnings.W021)
|
||||||
|
break
|
||||||
result = <MapStruct*>map_get(current_node, token)
|
result = <MapStruct*>map_get(current_node, token)
|
||||||
if not result:
|
if not result:
|
||||||
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
|
@ -212,11 +215,11 @@ cdef class PhraseMatcher:
|
||||||
map_set(self.mem, current_node, token, internal_node)
|
map_set(self.mem, current_node, token, internal_node)
|
||||||
result = internal_node
|
result = internal_node
|
||||||
current_node = <MapStruct*>result
|
current_node = <MapStruct*>result
|
||||||
result = <MapStruct*>map_get(current_node, self._terminal_node)
|
result = <MapStruct*>map_get(current_node, self._terminal_hash)
|
||||||
if not result:
|
if not result:
|
||||||
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
map_init(self.mem, internal_node, 8)
|
map_init(self.mem, internal_node, 8)
|
||||||
map_set(self.mem, current_node, self._terminal_node, internal_node)
|
map_set(self.mem, current_node, self._terminal_hash, internal_node)
|
||||||
result = internal_node
|
result = internal_node
|
||||||
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
|
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
|
||||||
|
|
||||||
|
@ -230,17 +233,13 @@ cdef class PhraseMatcher:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/phrasematcher#call
|
DOCS: https://spacy.io/api/phrasematcher#call
|
||||||
"""
|
"""
|
||||||
doc_array = self._convert_to_array(doc)
|
|
||||||
matches = []
|
matches = []
|
||||||
if doc_array is None or len(doc_array) == 0:
|
if doc is None or len(doc) == 0:
|
||||||
# if doc_array is empty or None just return empty list
|
# if doc is empty or None just return empty list
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
if not doc_array.flags['C_CONTIGUOUS']:
|
|
||||||
doc_array = np.ascontiguousarray(doc_array)
|
|
||||||
cdef key_t[::1] doc_array_memview = doc_array
|
|
||||||
cdef vector[MatchStruct] c_matches
|
cdef vector[MatchStruct] c_matches
|
||||||
self.find_matches(&doc_array_memview[0], doc_array_memview.shape[0], &c_matches)
|
self.find_matches(doc, &c_matches)
|
||||||
for i in range(c_matches.size()):
|
for i in range(c_matches.size()):
|
||||||
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
|
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
|
||||||
for i, (ent_id, start, end) in enumerate(matches):
|
for i, (ent_id, start, end) in enumerate(matches):
|
||||||
|
@ -249,7 +248,7 @@ cdef class PhraseMatcher:
|
||||||
on_match(self, doc, i, matches)
|
on_match(self, doc, i, matches)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil:
|
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
|
||||||
cdef MapStruct* current_node = self.c_map
|
cdef MapStruct* current_node = self.c_map
|
||||||
cdef int start = 0
|
cdef int start = 0
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
|
@ -259,22 +258,22 @@ cdef class PhraseMatcher:
|
||||||
cdef int i = 0
|
cdef int i = 0
|
||||||
cdef MatchStruct ms
|
cdef MatchStruct ms
|
||||||
cdef void* result
|
cdef void* result
|
||||||
while idx < hash_array_len:
|
while idx < doc.length:
|
||||||
start = idx
|
start = idx
|
||||||
token = hash_array[idx]
|
token = Token.get_struct_attr(&doc.c[idx], self.attr)
|
||||||
# look for sequences from this position
|
# look for sequences from this position
|
||||||
result = map_get(current_node, token)
|
result = map_get(current_node, token)
|
||||||
if result:
|
if result:
|
||||||
current_node = <MapStruct*>result
|
current_node = <MapStruct*>result
|
||||||
idy = idx + 1
|
idy = idx + 1
|
||||||
while idy < hash_array_len:
|
while idy < doc.length:
|
||||||
result = map_get(current_node, self._terminal_node)
|
result = map_get(current_node, self._terminal_hash)
|
||||||
if result:
|
if result:
|
||||||
i = 0
|
i = 0
|
||||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||||
ms = make_matchstruct(key, start, idy)
|
ms = make_matchstruct(key, start, idy)
|
||||||
matches.push_back(ms)
|
matches.push_back(ms)
|
||||||
inner_token = hash_array[idy]
|
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
|
||||||
result = map_get(current_node, inner_token)
|
result = map_get(current_node, inner_token)
|
||||||
if result:
|
if result:
|
||||||
current_node = <MapStruct*>result
|
current_node = <MapStruct*>result
|
||||||
|
@ -282,8 +281,8 @@ cdef class PhraseMatcher:
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# end of hash_array reached
|
# end of doc reached
|
||||||
result = map_get(current_node, self._terminal_node)
|
result = map_get(current_node, self._terminal_hash)
|
||||||
if result:
|
if result:
|
||||||
i = 0
|
i = 0
|
||||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||||
|
@ -325,28 +324,8 @@ cdef class PhraseMatcher:
|
||||||
else:
|
else:
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def get_lex_value(self, Doc doc, int i):
|
|
||||||
if self.attr == ORTH:
|
|
||||||
# Return the regular orth value of the lexeme
|
|
||||||
return doc.c[i].lex.orth
|
|
||||||
# Get the attribute value instead, e.g. token.pos
|
|
||||||
attr_value = get_token_attr(&doc.c[i], self.attr)
|
|
||||||
if attr_value in (0, 1):
|
|
||||||
# Value is boolean, convert to string
|
|
||||||
string_attr_value = str(attr_value)
|
|
||||||
else:
|
|
||||||
string_attr_value = self.vocab.strings[attr_value]
|
|
||||||
string_attr_name = self.vocab.strings[self.attr]
|
|
||||||
# Concatenate the attr name and value to not pollute lexeme space
|
|
||||||
# e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
|
|
||||||
# create false positive matches
|
|
||||||
matcher_attr_string = "matcher:{}-{}".format(string_attr_name, string_attr_value)
|
|
||||||
# Add new string to vocab
|
|
||||||
_ = self.vocab[matcher_attr_string]
|
|
||||||
return self.vocab.strings[matcher_attr_string]
|
|
||||||
|
|
||||||
def _convert_to_array(self, Doc doc):
|
def _convert_to_array(self, Doc doc):
|
||||||
return np.array([self.get_lex_value(doc, i) for i in range(len(doc))], dtype=np.uint64)
|
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
|
||||||
|
|
||||||
|
|
||||||
def unpickle_matcher(vocab, docs, callbacks):
|
def unpickle_matcher(vocab, docs, callbacks):
|
||||||
|
|
Loading…
Reference in New Issue