From e1a25fba32b6a0e5813c280b4e150116a115eed0 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 12 Jul 2015 19:58:05 +0200 Subject: [PATCH] * Work on huffman coder --- spacy/serialize.pyx | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx index b9490eff2..9e0cf1935 100644 --- a/spacy/serialize.pyx +++ b/spacy/serialize.pyx @@ -59,14 +59,15 @@ cdef class HuffmanCodec: self.table = PreshMap() cdef bytes symb_str cdef uint64_t key - cdef uint64_t i + cdef uint32_t i for i, symbol in enumerate(symbols): if type(symbol) == unicode or type(symbol) == bytes: symb_str = symbol.encode('utf8') key = hash64(symb_str, len(symb_str), 0) else: key = int(symbol) - self.table[key] = i + self.table[key] = i+1 + self.symbols = symbols self.probs = probs self.codes.resize(len(probs)) for i in range(len(self.codes)): @@ -78,11 +79,17 @@ cdef class HuffmanCodec: path.length = 0 assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) - def encode(self, uint64_t[:] sequence): + def encode(self, sequence): cdef vector[bint] bits - cdef uint64_t symbol + cdef uint64_t key + cdef uint64_t i for symbol in sequence: - i = self.table.get(symbol) + if type(symbol) == unicode or type(symbol) == bytes: + symb_str = symbol.encode('utf8') + key = hash64(symb_str, len(symb_str), 0) + else: + key = int(symbol) + i = self.table.get(key) if i == 0: raise Exception("Unseen symbol: %s" % symbol) else: @@ -90,15 +97,28 @@ cdef class HuffmanCodec: bits.extend(code) return bits - def decode(self, bits): + def decode(self, unsigned char[:] data): symbols = [] node = self.nodes.back() + bits = [] + cdef unsigned char byte + cdef unsigned char one + cdef int i = 0 + for byte_ in data: + for i in range(7, -1, -1): + bits.append(bool(byte & (one << i))) + + cdef bint bit = 0 + for bit in bits: branch = node.right if bit else node.left if branch >= 0: node = self.nodes.at(branch) else: - symbols.append(-(branch + 1)) + symbol = self.symbols[-(branch + 1)] + if symbol == self.eol_symbol: + break + symbols.append(symbol) node = self.nodes.back() return symbols