diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx
index a5172a556..303b073db 100644
--- a/spacy/serialize.pyx
+++ b/spacy/serialize.pyx
@@ -41,6 +41,55 @@ cdef Code bit_append(Code code, bint bit) nogil:
         code.bits &= ~(one << code.length)
     code.length += 1
     return code
+
+
+cdef class BitArray:
+    """Append-only bit buffer that packs bits into bytes.
+
+    Bits fill each byte starting from the least-significant position; a byte
+    is flushed to the buffer once all eight positions have been written.
+    """
+    # Declared but not read or written by any method of this class.
+    cdef int length
+    # Completed bytes. A bytearray gives amortised O(1) appends and accepts
+    # integer byte values on both Python 2 and Python 3.
+    cdef bytearray data
+    # Partially filled byte currently being assembled.
+    cdef unsigned char byte
+    # Index (0-7) of the next bit position to write in ``byte``.
+    cdef unsigned char bit_of_byte
+
+    def __init__(self):
+        self.data = bytearray()
+        self.byte = 0
+        self.bit_of_byte = 0
+
+    def as_bytes(self):
+        """Return the packed bits as a ``bytes`` object.
+
+        A trailing partial byte is included, with its unwritten high bits
+        left as zero. The internal buffer is not modified.
+        """
+        cdef bytearray output = bytearray(self.data)
+        if self.bit_of_byte != 0:
+            output.append(self.byte)
+        return bytes(output)
+
+    cdef int extend(self, uint64_t code, char n_bits) except -1:
+        """Append the lowest ``n_bits`` bits of ``code``, bit 0 first."""
+        cdef uint64_t one = 1
+        cdef unsigned char bit_of_code
+        for bit_of_code in range(n_bits):
+            if code & (one << bit_of_code):
+                self.byte |= one << self.bit_of_byte
+            else:
+                self.byte &= ~(one << self.bit_of_byte)
+            self.bit_of_byte += 1
+            if self.bit_of_byte == 8:
+                # Byte is full: flush it and start a fresh one.
+                self.data.append(self.byte)
+                self.byte = 0
+                self.bit_of_byte = 0
 
 
 cdef class HuffmanCodec:
@@ -75,27 +124,16 @@ cdef class HuffmanCodec:
         assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
 
     def encode(self, uint32_t[:] sequence):
-        cdef Code code
-        cdef bytes output = b''
-        cdef unsigned char byte = 0
-        cdef uint64_t one = 1
-        cdef unsigned char i_of_byte = 0
-        cdef unsigned char i_of_code = 0
-        for index in list(sequence) + [self.eol]:
-            code = self.codes[index]
-            for i_of_code in range(code.length):
-                if code.bits & (one << i_of_code):
-                    byte |= one << i_of_byte
-                else:
-                    byte &= ~(one << i_of_byte)
-                i_of_byte += 1
-                if i_of_byte == 8:
-                    output += chr(byte)
-                    byte = 0
-                    i_of_byte = 0
-        if i_of_byte != 0:
-            output += chr(byte)
-        return output
+        """Huffman-encode ``sequence`` and return the packed bits as bytes.
+
+        The code for each symbol is appended in order, followed by the code
+        stored at index ``self.eol``.
+        """
+        cdef BitArray bits = BitArray()
+        for i in sequence:
+            bits.extend(self.codes[i].bits, self.codes[i].length)
+        bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
+        return bits.as_bytes()
 
     def decode(self, bytes data):
         node = self.nodes.back()