diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx index 148ac9113..a2ffec62e 100644 --- a/spacy/serialize.pyx +++ b/spacy/serialize.pyx @@ -31,6 +31,8 @@ cdef struct Node: cdef class HuffmanCodec: cdef vector[Node] nodes cdef vector[vector[bint]] codes + cdef vector[bint] oov_code + cdef uint64_t oov_symbol cdef float[:] probs cdef dict table def __init__(self, symbols, probs): @@ -44,11 +46,15 @@ cdef class HuffmanCodec: cdef vector[bint] path assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) - def encode(self, sequence): + def encode(self, uint64_t[:] sequence): bits = [] + cdef uint64_t symbol for symbol in sequence: - i = self.table[symbol] - code = self.codes[i] + i = self.table.get(symbol) + if i == 0: + raise Exception("Unseen symbol: %s" % symbol) + else: + code = self.codes[i] bits.extend(code) return bits