mirror of https://github.com/explosion/spaCy.git
* Work on compressor
This commit is contained in:
parent
14eafcab15
commit
aa7bfd932b
|
@ -31,6 +31,8 @@ cdef struct Node:
|
||||||
cdef class HuffmanCodec:
|
cdef class HuffmanCodec:
|
||||||
cdef vector[Node] nodes
|
cdef vector[Node] nodes
|
||||||
cdef vector[vector[bint]] codes
|
cdef vector[vector[bint]] codes
|
||||||
|
cdef vector[bint] oov_code
|
||||||
|
cdef uint64_t oov_symbol
|
||||||
cdef float[:] probs
|
cdef float[:] probs
|
||||||
cdef dict table
|
cdef dict table
|
||||||
def __init__(self, symbols, probs):
|
def __init__(self, symbols, probs):
|
||||||
|
@ -44,11 +46,15 @@ cdef class HuffmanCodec:
|
||||||
cdef vector[bint] path
|
cdef vector[bint] path
|
||||||
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
|
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
|
||||||
|
|
||||||
def encode(self, sequence):
|
def encode(self, uint64_t[:] sequence):
|
||||||
bits = []
|
bits = []
|
||||||
|
cdef uint64_t symbol
|
||||||
for symbol in sequence:
|
for symbol in sequence:
|
||||||
i = self.table[symbol]
|
i = <int>self.table.get(symbol)
|
||||||
code = self.codes[i]
|
if i == 0:
|
||||||
|
raise Exception("Unseen symbol: %s" % symbol)
|
||||||
|
else:
|
||||||
|
code = self.codes[i]
|
||||||
bits.extend(code)
|
bits.extend(code)
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue