mirror of https://github.com/explosion/spaCy.git
* Make huffman coder take BitArray in encode/decode. Add __iter__ method to BitArray.
This commit is contained in:
parent
af5cc926a4
commit
edd371246c
|
@ -11,6 +11,8 @@ import numpy
|
|||
|
||||
cimport cython
|
||||
|
||||
ctypedef unsigned char uchar
|
||||
|
||||
# Format
|
||||
# - Total number of bytes in message (32 bit int)
|
||||
# - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
|
||||
|
@ -33,14 +35,29 @@ cdef Code bit_append(Code code, bint bit) nogil:
|
|||
|
||||
|
||||
cdef class BitArray:
|
||||
cdef int length
|
||||
cdef bytes data
|
||||
cdef unsigned char byte
|
||||
cdef unsigned char bit_of_byte
|
||||
cdef uint32_t i
|
||||
def __init__(self):
    """Create an empty bit array.

    The array starts with no completed bytes, a zeroed partial byte, and
    both the partial-byte position and the ``i`` bit index set to 0.
    """
    # Bit index that __iter__ reads to decide where iteration starts.
    self.i = 0
    # Position of the next bit to write within the partial byte.
    self.bit_of_byte = 0
    # Partial byte currently being filled.
    self.byte = 0
    # Completed bytes.
    self.data = b''
|
||||
|
||||
def __iter__(self):
    """Yield the stored bits, lowest bit of each byte first.

    Iteration starts at bit index ``self.i`` and covers the remaining
    completed bytes in ``self.data``, followed by the bits written so far
    into the partial byte ``self.byte``.

    Each yielded value is the byte masked with ``1 << position``: zero for
    a clear bit and non-zero (not necessarily 1) for a set bit, so callers
    should only test truthiness.
    """
    cdef uchar byte, i
    cdef uchar one = 1
    start_byte = self.i // 8
    start_bit = self.i % 8
    n_bytes = len(self.data)
    if start_byte > n_bytes:
        # The start index lies past the partial byte too: nothing is left.
        return
    if start_bit != 0 and start_byte < n_bytes:
        # The index falls inside a completed byte. Emit the bits from the
        # index up to the end of that byte (not the bits before it).
        for i in range(start_bit, 8):
            yield (self.data[start_byte] & (one << i))
        start_byte += 1
        start_bit = 0
    for byte in self.data[start_byte:]:
        for i in range(8):
            yield byte & (one << i)
    # start_bit is still non-zero only when the index points into the
    # partial byte itself; skip the bits that precede it.
    for i in range(start_bit, self.bit_of_byte):
        yield self.byte & (one << i)
|
||||
|
||||
def as_bytes(self):
|
||||
if self.bit_of_byte != 0:
|
||||
|
@ -48,6 +65,18 @@ cdef class BitArray:
|
|||
else:
|
||||
return self.data
|
||||
|
||||
def append(self, bint bit):
    """Append a single bit to the array.

    The bit is written at position ``self.bit_of_byte`` of the partial
    byte ``self.byte``. Once eight bits have been written, the byte is
    moved onto the end of ``self.data`` and the partial byte is reset.
    """
    cdef uint64_t one = 1
    if bit:
        self.byte |= one << self.bit_of_byte
    else:
        self.byte &= ~(one << self.bit_of_byte)
    self.bit_of_byte += 1
    if self.bit_of_byte == 8:
        # chr() returns text on Python 3, which cannot be concatenated to
        # bytes. Going through bytearray yields a one-byte bytes object on
        # both Python 2 and Python 3.
        self.data += bytes(bytearray((self.byte,)))
        self.byte = 0
        self.bit_of_byte = 0
|
||||
|
||||
cdef int extend(self, uint64_t code, char n_bits) except -1:
|
||||
cdef uint64_t one = 1
|
||||
cdef unsigned char bit_of_code
|
||||
|
@ -91,31 +120,28 @@ cdef class HuffmanCodec:
|
|||
path.length = 0
|
||||
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
|
||||
|
||||
def encode(self, uint32_t[:] sequence, BitArray bits=None):
    """Write the codes for ``sequence`` into a BitArray and return it.

    sequence: symbol ids, each used as an index into ``self.codes``.
    bits: optional BitArray to extend; a new one is created when None.

    The code for ``self.eol`` is written after the last symbol, so the
    decoder can tell where the message ends.
    """
    if bits is None:
        bits = BitArray()
    for i in sequence:
        bits.extend(self.codes[i].bits, self.codes[i].length)
    bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
    return bits
|
||||
|
||||
def decode(self, BitArray bits):
    """Decode symbols from ``bits`` until the EOL symbol is reached.

    Starting from ``self.nodes.back()`` (presumably the tree root;
    confirm against the tree construction), each bit selects the right
    branch when set and the left branch when clear. A non-negative branch
    is the index of the next node. A negative branch marks a leaf whose
    symbol is ``-(branch + 1)``.

    Returns the list of decoded symbols, excluding EOL. If the bits run
    out before EOL is seen, the symbols decoded so far are returned.
    """
    node = self.nodes.back()
    symbols = []
    for bit in bits:
        branch = node.right if bit else node.left
        if branch >= 0:
            node = self.nodes.at(branch)
        else:
            symbol = -(branch + 1)
            if symbol == self.eol:
                return symbols
            else:
                symbols.append(symbol)
                # Leaf reached: restart at the starting node for the
                # next symbol.
                node = self.nodes.back()
    return symbols
|
||||
|
||||
property strings:
|
||||
|
|
Loading…
Reference in New Issue