* Work on huffman coder

This commit is contained in:
Matthew Honnibal 2015-07-12 19:58:05 +02:00
parent 3fb9de2d13
commit e1a25fba32
1 changed file with 27 additions and 7 deletions

View File

@@ -59,14 +59,15 @@ cdef class HuffmanCodec:
self.table = PreshMap() self.table = PreshMap()
cdef bytes symb_str cdef bytes symb_str
cdef uint64_t key cdef uint64_t key
cdef uint64_t i cdef uint32_t i
for i, symbol in enumerate(symbols): for i, symbol in enumerate(symbols):
if type(symbol) == unicode or type(symbol) == bytes: if type(symbol) == unicode or type(symbol) == bytes:
symb_str = symbol.encode('utf8') symb_str = symbol.encode('utf8')
key = hash64(<unsigned char*>symb_str, len(symb_str), 0) key = hash64(<unsigned char*>symb_str, len(symb_str), 0)
else: else:
key = int(symbol) key = int(symbol)
self.table[key] = i self.table[key] = i+1
self.symbols = symbols
self.probs = probs self.probs = probs
self.codes.resize(len(probs)) self.codes.resize(len(probs))
for i in range(len(self.codes)): for i in range(len(self.codes)):
@@ -78,11 +79,17 @@ cdef class HuffmanCodec:
path.length = 0 path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, uint64_t[:] sequence): def encode(self, sequence):
cdef vector[bint] bits cdef vector[bint] bits
cdef uint64_t symbol cdef uint64_t key
cdef uint64_t i
for symbol in sequence: for symbol in sequence:
i = <size_t>self.table.get(symbol) if type(symbol) == unicode or type(symbol) == bytes:
symb_str = symbol.encode('utf8')
key = hash64(<unsigned char*>symb_str, len(symb_str), 0)
else:
key = int(symbol)
i = <uint32_t>self.table.get(key)
if i == 0: if i == 0:
raise Exception("Unseen symbol: %s" % symbol) raise Exception("Unseen symbol: %s" % symbol)
else: else:
@@ -90,15 +97,28 @@ cdef class HuffmanCodec:
bits.extend(code) bits.extend(code)
return bits return bits
def decode(self, bits): def decode(self, unsigned char[:] data):
symbols = [] symbols = []
node = self.nodes.back() node = self.nodes.back()
bits = []
cdef unsigned char byte
cdef unsigned char one
cdef int i = 0
for byte_ in data:
for i in range(7, -1, -1):
bits.append(bool(byte & (one << i)))
cdef bint bit = 0
for bit in bits: for bit in bits:
branch = node.right if bit else node.left branch = node.right if bit else node.left
if branch >= 0: if branch >= 0:
node = self.nodes.at(branch) node = self.nodes.at(branch)
else: else:
symbols.append(-(branch + 1)) symbol = self.symbols[-(branch + 1)]
if symbol == self.eol_symbol:
break
symbols.append(symbol)
node = self.nodes.back() node = self.nodes.back()
return symbols return symbols