diff --git a/spacy/index.pyx b/spacy/index.pyx index 5cdd6ab5b..b2f54c785 100644 --- a/spacy/index.pyx +++ b/spacy/index.pyx @@ -10,6 +10,7 @@ from .tokens cimport TokenC from .typedefs cimport hash_t from preshed.maps cimport MapStruct, Cell, map_get, map_set, map_init +from murmurhash.mrmr cimport hash64 cdef class Index: @@ -34,94 +35,75 @@ cdef class Index: self.counts.push_back(doc_counts) - -cdef class PosMemory: - def __init__(self, tag_names): - self.tag_names = tag_names - self.nr_tags = len(tag_names) +cdef class DecisionMemory: + def __init__(self, class_names): + self.class_names = class_names + self.n_classes = len(class_names) self.mem = Pool() self._counts = PreshCounter() - self._pos_counts = PreshCounter() + self._class_counts = PreshCounter() def __getitem__(self, ids): - cdef id_t[2] ngram - ngram[0] = ids[0] - ngram[1] = ids[1] - cdef hash_t ngram_key = hash64(ngram, 2 * sizeof(id_t), 0) - cdef hash_t[2] pos_context - pos_context[0] = ngram_key + cdef id_t[2] context + context[0] = context[0] + context[1] = context[1] + cdef hash_t context_key = hash64(context, 2 * sizeof(id_t), 0) + cdef hash_t[2] class_context + class_context[0] = context_key counts = {} cdef id_t i - for i, tag in enumerate(self.tag_names): - pos_context[1] = i - key = hash64(pos_context, sizeof(hash_t) * 2, 0) - count = self._pos_counts[key] - counts[tag] = count + for i, clas in enumerate(self.clas_names): + class_context[1] = i + key = hash64(class_context, sizeof(hash_t) * 2, 0) + count = self._class_counts[key] + counts[clas] = count return counts @cython.cdivision(True) - def iter_ngrams(self, float min_acc=0.99, count_t min_freq=10): - cdef Address counts_addr = Address(self.nr_tags, sizeof(count_t)) + def iter_contexts(self, float min_acc=0.99, count_t min_freq=10): + cdef Address counts_addr = Address(self.n_classes, sizeof(count_t)) cdef count_t* counts = counts_addr.ptr - cdef MapStruct* ngram_counts = self._counts.c_map - cdef hash_t ngram_key - cdef count_t ngram_freq - cdef int best_pos + cdef MapStruct* context_counts = self._counts.c_map + cdef hash_t context_key + cdef count_t context_freq + cdef int best_class cdef float acc cdef int i - for i in range(ngram_counts.length): - ngram_key = ngram_counts.cells[i].key - ngram_freq = ngram_counts.cells[i].value - if ngram_key != 0 and ngram_freq >= min_freq: - best_pos = self.find_best_pos(counts, ngram_key) - acc = counts[best_pos] / ngram_freq + for i in range(context_counts.length): + context_key = context_counts.cells[i].key + context_freq = context_counts.cells[i].value + if context_key != 0 and context_freq >= min_freq: + best_class = self.find_best_class(counts, context_key) + acc = counts[best_class] / context_freq if acc >= min_acc: - yield counts[best_pos], ngram_key, best_pos + yield counts[best_class], context_key, best_class - cpdef int count(self, Tokens tokens) except -1: - cdef int i - cdef TokenC* t - for i in range(tokens.length): - t = &tokens.data[i] - if t.lex.prob != 0 and t.lex.prob >= -14: - self.inc(t, 1) + cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1: + cdef hash_t context_and_class_key + cdef hash_t[2] context_and_class + context_and_class[0] = context_key + context_and_class[1] = clas + context_and_class_key = hash64(context_and_class, 2 * sizeof(hash_t), 0) + self._counts.inc(context_key, inc) + self._class_counts.inc(context_and_class_key, inc) - cdef int inc(self, TokenC* word, count_t inc) except -1: - cdef hash_t[2] ngram_pos_context - cdef hash_t ngram_key = self._ngram_key(word) - ngram_pos_context[0] = ngram_key - ngram_pos_context[1] = word.pos - ngram_pos_key = hash64(ngram_pos_context, 2 * sizeof(hash_t), 0) - self._counts.inc(ngram_key, inc) - self._pos_counts.inc(ngram_pos_key, inc) - - cdef int find_best_pos(self, count_t* counts, hash_t ngram_key) except -1: + cdef int find_best_class(self, count_t* counts, hash_t context_key) except -1: cdef hash_t[2] unhashed_key - unhashed_key[0] = ngram_key + unhashed_key[0] = context_key cdef count_t total = 0 cdef hash_t key - cdef int pos + cdef int clas cdef int best cdef int mode = 0 - for pos in range(self.nr_tags): - unhashed_key[1] = pos + for clas in range(self.n_classes): + unhashed_key[1] = clas key = hash64(unhashed_key, sizeof(hash_t) * 2, 0) - count = self._pos_counts[key] - counts[pos] = count + count = self._class_counts[key] + counts[clas] = count if count >= mode: mode = count - best = pos + best = clas total += count return best - - cdef count_t ngram_count(self, TokenC* word) except -1: - cdef hash_t ngram_key = self._ngram_key(word) - return self._counts[ngram_key] - - cdef hash_t _ngram_key(self, TokenC* word) except 0: - cdef id_t[2] context - context[0] = word.lex.sic - context[1] = word[-1].lex.sic - return hash64(context, sizeof(id_t) * 2, 0)