mirror of https://github.com/explosion/spaCy.git
* Add faster encode_int32 and decode_int32 methods
This commit is contained in:
parent
dd60594f41
commit
c6cd0ddce8
|
@ -4,7 +4,7 @@ from libc.stdint cimport int64_t
|
|||
from libc.stdint cimport int32_t
|
||||
from libc.stdint cimport uint64_t
|
||||
|
||||
from .bits cimport Code
|
||||
from .bits cimport BitArray, Code
|
||||
|
||||
|
||||
cdef struct Node:
|
||||
|
@ -19,3 +19,6 @@ cdef class HuffmanCodec:
|
|||
|
||||
cdef readonly list leaves
|
||||
cdef readonly dict _map
|
||||
|
||||
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1
|
||||
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1
|
||||
|
|
|
@ -61,6 +61,19 @@ cdef class HuffmanCodec:
|
|||
bits.extend(self.codes[i].bits, self.codes[i].length)
|
||||
return bits
|
||||
|
||||
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1:
    """Append the code of every symbol in ``msg`` to ``bits``.

    Each symbol is looked up in ``self._map`` to obtain an index into
    ``self.codes``; that entry's ``bits``/``length`` pair is pushed onto
    ``bits`` with ``bits.extend``.

    Returns the sum of the ``length`` fields of the codes written.

    Returns 0 as soon as a symbol is missing from ``self._map``. In that
    case the codes of the symbols that preceded it have already been
    appended to ``bits``. An empty ``msg`` also yields 0, so callers
    cannot tell the two cases apart from the return value alone.
    """
    cdef int msg_i
    cdef int leaf_i
    cdef int length = 0
    for msg_i in range(msg.shape[0]):
        leaf_i = self._map.get(msg[msg_i], -1)
        # -1 is a numeric sentinel, not a singleton object: compare it by
        # value rather than by identity.
        if leaf_i == -1:
            return 0
        code = self.codes[leaf_i]
        bits.extend(code.bits, code.length)
        length += code.length
    return length
|
||||
|
||||
def n_bits(self, msg, overhead=0):
|
||||
cdef int i
|
||||
length = 0
|
||||
|
@ -88,8 +101,39 @@ cdef class HuffmanCodec:
|
|||
if i == n:
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
"Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
|
||||
raise Exception("Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
|
||||
|
||||
@cython.boundscheck(False)
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1:
    """Fill ``msg`` with ``msg.shape[0]`` symbols decoded from ``bits``.

    The input is taken from ``bits.as_bytes()`` and consumed one byte at a
    time, testing bit 0 first and bit 7 last. A set bit follows
    ``node.right`` and a clear bit follows ``node.left``. A non-negative
    branch is an index into ``self.nodes``; a negative branch marks a leaf
    whose symbol is ``self.leaves[-(branch + 1)]``.

    Decoding stops as soon as ``msg`` is full; any bits left in the
    current byte are ignored.

    Raises ``Exception`` when the bytes run out before ``msg`` has been
    filled. No explicit return value is set on success.
    """
    cdef Node node = self.root
    cdef int branch

    cdef int n_msg = msg.shape[0]
    cdef bytes bytes_ = bits.as_bytes()
    cdef int n_bytes = len(bytes_)
    cdef unsigned char byte
    cdef int i_msg = 0
    cdef int i_byte = 0
    cdef int i_bit = 0
    cdef unsigned char bit
    cdef int32_t one = 1
    while i_msg < n_msg:
        # Bounds checking is switched off for this function, so the end of
        # the buffer has to be detected by hand before the read below.
        # This is the only place the exhaustion error is raised: a
        # while/else would also fire on normal completion, because the
        # inner ``break`` only leaves the ``for`` loop.
        if i_byte >= n_bytes:
            raise Exception("Buffer exhausted at %d/%d symbols read." % (i_msg, len(msg)))
        byte = bytes_[i_byte]
        for i_bit in range(8):
            bit = byte & (one << i_bit)
            branch = node.right if bit else node.left
            if branch >= 0:
                node = self.nodes.at(branch)
            else:
                msg[i_msg] = self.leaves[-(branch + 1)]
                # Restart the walk for the next symbol. Presumably
                # self.nodes.back() is the same node as self.root --
                # TODO confirm.
                node = self.nodes.back()
                i_msg += 1
                if i_msg == n_msg:
                    break
        i_byte += 1
|
||||
|
||||
|
||||
|
||||
property strings:
|
||||
@cython.boundscheck(False)
|
||||
|
|
Loading…
Reference in New Issue