mirror of https://github.com/explosion/spaCy.git
* Work on huffman coder
This commit is contained in:
parent
3fb9de2d13
commit
e1a25fba32
|
@ -59,14 +59,15 @@ cdef class HuffmanCodec:
|
|||
self.table = PreshMap()
|
||||
cdef bytes symb_str
|
||||
cdef uint64_t key
|
||||
cdef uint64_t i
|
||||
cdef uint32_t i
|
||||
for i, symbol in enumerate(symbols):
|
||||
if type(symbol) == unicode or type(symbol) == bytes:
|
||||
symb_str = symbol.encode('utf8')
|
||||
key = hash64(<unsigned char*>symb_str, len(symb_str), 0)
|
||||
else:
|
||||
key = int(symbol)
|
||||
self.table[key] = i
|
||||
self.table[key] = i+1
|
||||
self.symbols = symbols
|
||||
self.probs = probs
|
||||
self.codes.resize(len(probs))
|
||||
for i in range(len(self.codes)):
|
||||
|
@ -78,11 +79,17 @@ cdef class HuffmanCodec:
|
|||
path.length = 0
|
||||
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
|
||||
|
||||
def encode(self, uint64_t[:] sequence):
|
||||
def encode(self, sequence):
|
||||
cdef vector[bint] bits
|
||||
cdef uint64_t symbol
|
||||
cdef uint64_t key
|
||||
cdef uint64_t i
|
||||
for symbol in sequence:
|
||||
i = <size_t>self.table.get(symbol)
|
||||
if type(symbol) == unicode or type(symbol) == bytes:
|
||||
symb_str = symbol.encode('utf8')
|
||||
key = hash64(<unsigned char*>symb_str, len(symb_str), 0)
|
||||
else:
|
||||
key = int(symbol)
|
||||
i = <uint32_t>self.table.get(key)
|
||||
if i == 0:
|
||||
raise Exception("Unseen symbol: %s" % symbol)
|
||||
else:
|
||||
|
@ -90,15 +97,28 @@ cdef class HuffmanCodec:
|
|||
bits.extend(code)
|
||||
return bits
|
||||
|
||||
def decode(self, bits):
|
||||
def decode(self, unsigned char[:] data):
|
||||
symbols = []
|
||||
node = self.nodes.back()
|
||||
bits = []
|
||||
cdef unsigned char byte
|
||||
cdef unsigned char one
|
||||
cdef int i = 0
|
||||
for byte_ in data:
|
||||
for i in range(7, -1, -1):
|
||||
bits.append(bool(byte & (one << i)))
|
||||
|
||||
cdef bint bit = 0
|
||||
|
||||
for bit in bits:
|
||||
branch = node.right if bit else node.left
|
||||
if branch >= 0:
|
||||
node = self.nodes.at(branch)
|
||||
else:
|
||||
symbols.append(-(branch + 1))
|
||||
symbol = self.symbols[-(branch + 1)]
|
||||
if symbol == self.eol_symbol:
|
||||
break
|
||||
symbols.append(symbol)
|
||||
node = self.nodes.back()
|
||||
return symbols
|
||||
|
||||
|
|
Loading…
Reference in New Issue