diff --git a/spacy/serialize/huffman.pxd b/spacy/serialize/huffman.pxd
index 93ac3f035..e2f0600c8 100644
--- a/spacy/serialize/huffman.pxd
+++ b/spacy/serialize/huffman.pxd
@@ -4,7 +4,7 @@ from libc.stdint cimport int64_t
 from libc.stdint cimport int32_t
 from libc.stdint cimport uint64_t
 
-from .bits cimport Code
+from .bits cimport BitArray, Code
 
 
 cdef struct Node:
@@ -19,3 +19,6 @@ cdef class HuffmanCodec:
 
     cdef readonly list leaves
     cdef readonly dict _map
+
+    cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1
+    cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1
diff --git a/spacy/serialize/huffman.pyx b/spacy/serialize/huffman.pyx
index 87a8cc41a..5d42d1030 100644
--- a/spacy/serialize/huffman.pyx
+++ b/spacy/serialize/huffman.pyx
@@ -61,6 +61,26 @@ cdef class HuffmanCodec:
             bits.extend(self.codes[i].bits, self.codes[i].length)
         return bits
 
+    cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1:
+        """Append the Huffman code of each symbol in msg to bits.
+
+        Returns the total number of bits appended. Returns 0 as soon as a
+        symbol without a code is met; codes for the symbols before it have
+        already been appended to bits by then.
+        """
+        cdef int msg_i
+        cdef int leaf_i
+        cdef int length = 0
+        for msg_i in range(msg.shape[0]):
+            leaf_i = self._map.get(msg[msg_i], -1)
+            # Compare by value; "is" tests identity, not equality.
+            if leaf_i == -1:
+                return 0
+            code = self.codes[leaf_i]
+            bits.extend(code.bits, code.length)
+            length += code.length
+        return length
+
     def n_bits(self, msg, overhead=0):
         cdef int i
         length = 0
@@ -88,8 +108,50 @@ cdef class HuffmanCodec:
                 if i == n:
                     break
         else:
-            raise Exception(
-                "Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
+            raise Exception("Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
+
+    @cython.boundscheck(False)
+    cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1:
+        """Decode msg.shape[0] symbols from bits into msg.
+
+        Bits are consumed from the least significant bit of each byte of
+        bits.as_bytes(). Returns 0 once msg is full; raises Exception if
+        the bytes run out first.
+        """
+        cdef Node node = self.root
+        cdef int branch
+
+        cdef int n_msg = msg.shape[0]
+        cdef bytes bytes_ = bits.as_bytes()
+        cdef int n_bytes = len(bytes_)
+        cdef unsigned char byte
+        cdef int i_msg = 0
+        cdef int i_byte = 0
+        cdef int i_bit = 0
+        cdef unsigned char bit
+        cdef int32_t one = 1
+        while i_msg < n_msg:
+            # Bounds checking is disabled above, so guard the read by hand.
+            if i_byte >= n_bytes:
+                raise Exception("Buffer exhausted at %d/%d symbols read."
+                                % (i_msg, n_msg))
+            byte = bytes_[i_byte]
+            i_byte += 1
+            for i_bit in range(8):
+                bit = byte & (one << i_bit)
+                branch = node.right if bit else node.left
+                if branch >= 0:
+                    node = self.nodes.at(branch)
+                else:
+                    # Negative branch encodes a leaf: -(branch + 1) indexes leaves.
+                    msg[i_msg] = self.leaves[-(branch + 1)]
+                    node = self.nodes.back()
+                    i_msg += 1
+                    if i_msg == n_msg:
+                        break
+        return 0
+
+
 
     property strings:
         @cython.boundscheck(False)