From 281f1faefb649bf8bef022cf9c1eb99956f2e315 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 12 Jul 2015 23:48:46 +0200 Subject: [PATCH] * Nearly finished huffman coder --- spacy/serialize.pyx | 113 +++++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 65 deletions(-) diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx index 9e0cf1935..1880d1a0f 100644 --- a/spacy/serialize.pyx +++ b/spacy/serialize.pyx @@ -11,19 +11,12 @@ import numpy cimport cython - -#cdef class Serializer: -# def __init__(self, Vocab vocab): -# pass -# -# def dump(self, Doc tokens, file_): -# pass -# # Format -# # - Total number of bytes in message (32 bit int) -# # - Words, terminating in an EOL symbol, huffman coded ~12 bits per word -# # - Spaces ~1 bit per word -# # - Parse: Huffman coded head offset / dep label / POS tag / entity IOB tag -# # combo. ? bits per word. 40 * 80 * 40 * 12 = 1.5m symbol vocab +# Format +# - Total number of bytes in message (32 bit int) +# - Words, terminating in an EOL symbol, huffman coded ~12 bits per word +# - Spaces ~1 bit per word +# - Parse: Huffman coded head offset / dep label / POS tag / entity IOB tag +# combo. ? bits per word. 40 * 80 * 40 * 12 = 1.5m symbol vocab cdef struct Node: @@ -53,21 +46,11 @@ cdef Code bit_append(Code code, bint bit) nogil: cdef class HuffmanCodec: cdef vector[Node] nodes cdef vector[Code] codes - cdef float[:] probs + cdef readonly float[:] probs cdef PreshMap table - def __init__(self, symbols, probs): - self.table = PreshMap() - cdef bytes symb_str - cdef uint64_t key - cdef uint32_t i - for i, symbol in enumerate(symbols): - if type(symbol) == unicode or type(symbol) == bytes: - symb_str = symbol.encode('utf8') - key = hash64(symb_str, len(symb_str), 0) - else: - key = int(symbol) - self.table[key] = i+1 - self.symbols = symbols + cdef uint32_t eol + def __init__(self, probs, eol): + self.eol = eol self.probs = probs self.codes.resize(len(probs)) for i in range(len(self.codes)): @@ -79,47 +62,47 @@ cdef class HuffmanCodec: path.length = 0 assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) - def encode(self, sequence): - cdef vector[bint] bits - cdef uint64_t key - cdef uint64_t i - for symbol in sequence: - if type(symbol) == unicode or type(symbol) == bytes: - symb_str = symbol.encode('utf8') - key = hash64(symb_str, len(symb_str), 0) - else: - key = int(symbol) - i = self.table.get(key) - if i == 0: - raise Exception("Unseen symbol: %s" % symbol) - else: - code = self.codes[i] - bits.extend(code) - return bits + def encode(self, uint32_t[:] sequence): + cdef Code code + cdef bytes output = b'' + cdef unsigned char byte = 0 + cdef uint64_t one = 1 + cdef unsigned char i_of_byte = 0 + cdef unsigned char i_of_code = 0 + for index in sequence: + code = self.codes[index] + for i_of_code in range(code.length): + if code.bits & (one << i_of_code): + byte |= one << i_of_byte + else: + byte &= ~(one << i_of_byte) + i_of_byte += 1 + if i_of_byte == 8: + output += chr(byte) + byte = 0 + i_of_byte = 0 + if i_of_byte != 0: + output += chr(byte) + return output - def decode(self, unsigned char[:] data): - symbols = [] + def decode(self, bytes data): node = self.nodes.back() - bits = [] + symbols = [] cdef unsigned char byte - cdef unsigned char one - cdef int i = 0 - for byte_ in data: - for i in range(7, -1, -1): - bits.append(bool(byte & (one << i))) - - cdef bint bit = 0 - - for bit in bits: - branch = node.right if bit else node.left - if branch >= 0: - node = self.nodes.at(branch) - else: - symbol = self.symbols[-(branch + 1)] - if symbol == self.eol_symbol: - break - symbols.append(symbol) - node = self.nodes.back() + cdef unsigned char i = 0 + cdef unsigned char one = 1 + for byte in data: + for i in range(8): + branch = node.right if (byte & (one << i)) else node.left + if branch >= 0: + node = self.nodes.at(branch) + else: + symbol = -(branch + 1) + if symbol == self.eol: + return symbols + else: + symbols.append(symbol) + node = self.nodes.back() return symbols property strings: