mirror of https://github.com/explosion/spaCy.git
* Add faster encode_int32 and decode_int32 methods
This commit is contained in:
parent
dd60594f41
commit
c6cd0ddce8
|
@ -4,7 +4,7 @@ from libc.stdint cimport int64_t
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
from libc.stdint cimport uint64_t
|
from libc.stdint cimport uint64_t
|
||||||
|
|
||||||
from .bits cimport Code
|
from .bits cimport BitArray, Code
|
||||||
|
|
||||||
|
|
||||||
cdef struct Node:
|
cdef struct Node:
|
||||||
|
@ -19,3 +19,6 @@ cdef class HuffmanCodec:
|
||||||
|
|
||||||
cdef readonly list leaves
|
cdef readonly list leaves
|
||||||
cdef readonly dict _map
|
cdef readonly dict _map
|
||||||
|
|
||||||
|
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1
|
||||||
|
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1
|
||||||
|
|
|
@ -61,6 +61,19 @@ cdef class HuffmanCodec:
|
||||||
bits.extend(self.codes[i].bits, self.codes[i].length)
|
bits.extend(self.codes[i].bits, self.codes[i].length)
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
|
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1:
    """Append the code of every symbol in ``msg`` to ``bits``.

    Each symbol is looked up in ``self._map`` to find its index into
    ``self.codes``; that code's bit pattern is then appended to ``bits``.

    Returns the total number of bits appended. Returns 0 as soon as a
    symbol is missing from ``self._map``; codes appended for the symbols
    before it are NOT rolled back, so ``bits`` may be left partially
    extended. An empty ``msg`` also returns 0.
    """
    cdef int msg_i
    cdef int leaf_i
    cdef int length = 0
    for msg_i in range(msg.shape[0]):
        # -1 is the "symbol not in the table" sentinel.
        leaf_i = self._map.get(msg[msg_i], -1)
        # Compare by value: `is` on an integer is an identity test and
        # only works by accident of how the comparison gets compiled.
        if leaf_i == -1:
            return 0
        code = self.codes[leaf_i]
        bits.extend(code.bits, code.length)
        length += code.length
    return length
|
||||||
|
|
||||||
def n_bits(self, msg, overhead=0):
|
def n_bits(self, msg, overhead=0):
|
||||||
cdef int i
|
cdef int i
|
||||||
length = 0
|
length = 0
|
||||||
|
@ -88,8 +101,39 @@ cdef class HuffmanCodec:
|
||||||
if i == n:
|
if i == n:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception("Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
|
||||||
"Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
|
|
||||||
|
@cython.boundscheck(False)
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1:
    """Decode ``msg.shape[0]`` symbols from ``bits`` into ``msg``.

    The bit buffer is read as bytes, and the bits of each byte are
    consumed from least significant to most significant. Every bit picks
    the right (1) or left (0) branch of the current tree node. A
    non-negative branch is an index into ``self.nodes``; a negative
    branch ``b`` denotes the leaf ``self.leaves[-(b + 1)]``.

    Decoding stops as soon as ``msg`` is full; any remaining bits of the
    current byte are ignored.

    Raises:
        Exception: if the buffer runs out of bytes before ``msg`` has
            been filled.
    """
    cdef Node node = self.root
    cdef int branch

    cdef int n_msg = msg.shape[0]
    cdef bytes bytes_ = bits.as_bytes()
    cdef int n_bytes = len(bytes_)
    cdef unsigned char byte
    cdef int i_msg = 0
    cdef int i_byte = 0
    cdef int i_bit = 0
    cdef unsigned char bit
    cdef int32_t one = 1
    while i_msg < n_msg:
        # Bounds checking is disabled for this function, so the buffer
        # length must be verified by hand before every byte read.
        if i_byte >= n_bytes:
            raise Exception("Buffer exhausted at %d/%d symbols read." % (i_msg, n_msg))
        byte = bytes_[i_byte]
        for i_bit in range(8):
            bit = byte & (one << i_bit)
            branch = node.right if bit else node.left
            if branch >= 0:
                node = self.nodes.at(branch)
            else:
                msg[i_msg] = self.leaves[-(branch + 1)]
                # Restart the walk for the next symbol.
                # NOTE(review): assumes nodes.back() is the same node as
                # self.root used for the initial walk -- confirm.
                node = self.nodes.back()
                i_msg += 1
                if i_msg == n_msg:
                    # Only leaves the per-bit loop; the outer loop then
                    # terminates on its own condition.
                    break
        i_byte += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
property strings:
|
property strings:
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
|
|
Loading…
Reference in New Issue