From bb0ba1f0cddb4b60ca83359d5d6738f3aedce3ad Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 20 Jul 2015 03:27:59 +0200 Subject: [PATCH] * Improve serialization speed --- spacy/serialize/huffman.pyx | 3 +++ spacy/serialize/packer.pyx | 48 ++++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/spacy/serialize/huffman.pyx b/spacy/serialize/huffman.pyx index 27b88f5ae..87a8cc41a 100644 --- a/spacy/serialize/huffman.pyx +++ b/spacy/serialize/huffman.pyx @@ -1,3 +1,4 @@ +# cython: profile=True cimport cython from libcpp.queue cimport priority_queue from libcpp.pair cimport pair @@ -74,6 +75,8 @@ cdef class HuffmanCodec: node = self.root cdef int i = 0 cdef int n = len(msg) + cdef int branch + cdef bint bit for bit in bits: branch = node.right if bit else node.left if branch >= 0: diff --git a/spacy/serialize/packer.pyx b/spacy/serialize/packer.pyx index 7ed4e85e9..f0e376410 100644 --- a/spacy/serialize/packer.pyx +++ b/spacy/serialize/packer.pyx @@ -106,17 +106,21 @@ cdef class Packer: return cls(vocab, util.read_encoding_freqs(data_dir)) def pack(self, Doc doc): - orths = [t.orth for t in doc] - chars = doc.string.encode('utf8') + orths = doc.to_array([ORTH]) + orths = orths[:, 0] + cdef bytes chars = doc.string.encode('utf8') # n_bits returns nan for oov words, i.e. can't encode message. # So, it's important to write the conditional like this. if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1): - bits = self._orth_encode(doc) + bits = self._orth_encode(doc, orths) else: - bits = self._char_encode(doc) - array = doc.to_array(self.attrs) - for i, codec in enumerate(self._codecs): - codec.encode(array[:, i], bits) + bits = self._char_encode(doc, chars) + + cdef int i + if self.attrs: + array = doc.to_array(self.attrs) + for i, codec in enumerate(self._codecs): + codec.encode(array[:, i], bits) return bits def unpack(self, BitArray bits): @@ -134,9 +138,8 @@ cdef class Packer: doc.from_array(self.attrs, array) return doc - def _orth_encode(self, Doc doc): + def _orth_encode(self, Doc doc, attr_t[:] orths): cdef BitArray bits = BitArray() - orths = [w.orth for w in doc] cdef int32_t length = len(doc) bits.extend(length, 32) self.orth_codec.encode(orths, bits) @@ -145,46 +148,47 @@ cdef class Packer: return bits def _orth_decode(self, BitArray bits, n): - orths = [0] * n + orths = numpy.ndarray(shape=(n,), dtype=numpy.int32) self.orth_codec.decode(bits, orths) orths_and_spaces = zip(orths, bits) cdef Doc doc = Doc(self.vocab, orths_and_spaces) return doc - def _char_encode(self, Doc doc): + def _char_encode(self, Doc doc, bytes utf8_str): cdef BitArray bits = BitArray() - cdef bytes utf8_str = doc.string.encode('utf8') cdef int32_t length = len(utf8_str) # Signal chars with negative length bits.extend(-length, 32) self.char_codec.encode(utf8_str, bits) - for token in doc: - for i in range(len(token)-1): + cdef int i, j + for i in range(doc.length): + for j in range(doc.data[i].lex.length-1): bits.append(False) bits.append(True) - if token.whitespace_: + if doc.data[i].spacy: bits.append(False) return bits def _char_decode(self, BitArray bits, n): - chars = [b''] * n - self.char_codec.decode(bits, chars) - cdef bytes utf8_str = b''.join(chars) + cdef bytearray utf8_str = bytearray(n) + self.char_codec.decode(bits, utf8_str) cdef unicode string = utf8_str.decode('utf8') cdef Doc tokens = Doc(self.vocab) - cdef int i cdef int start = 0 cdef bint is_spacy cdef UniStr span cdef int length = len(string) - iter_bits = iter(bits) - for i in range(length): - is_end_token = iter_bits.next() + cdef int i = 0 + cdef bint is_end_token + for is_end_token in bits: if is_end_token: slice_unicode(&span, string, start, i+1) lex = self.vocab.get(tokens.mem, &span) is_spacy = (i+1) < length and string[i+1] == u' ' tokens.push_back(lex, is_spacy) start = i + 1 + is_spacy + i += 1 + if i >= n: + break return tokens