mirror of https://github.com/explosion/spaCy.git
* Round-trip for serialization finally working. Needs a lot of optimization.
This commit is contained in:
parent
edd371246c
commit
5b0a7190c9
|
@ -51,13 +51,13 @@ cdef class BitArray:
|
||||||
start_byte = self.i // 8
|
start_byte = self.i // 8
|
||||||
if (self.i % 8) != 0:
|
if (self.i % 8) != 0:
|
||||||
for i in range(self.i % 8):
|
for i in range(self.i % 8):
|
||||||
yield (self.data[start_byte] & (one << i))
|
yield 1 if (self.data[start_byte] & (one << i)) else 0
|
||||||
start_byte += 1
|
start_byte += 1
|
||||||
for byte in self.data[start_byte:]:
|
for byte in self.data[start_byte:]:
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
yield byte & (one << i)
|
yield 1 if byte & (one << i) else 0
|
||||||
for i in range(self.bit_of_byte):
|
for i in range(self.bit_of_byte):
|
||||||
yield self.byte & (one << i)
|
yield 1 if self.byte & (one << i) else 0
|
||||||
|
|
||||||
def as_bytes(self):
|
def as_bytes(self):
|
||||||
if self.bit_of_byte != 0:
|
if self.bit_of_byte != 0:
|
||||||
|
@ -67,6 +67,7 @@ cdef class BitArray:
|
||||||
|
|
||||||
def append(self, bint bit):
|
def append(self, bint bit):
|
||||||
cdef uint64_t one = 1
|
cdef uint64_t one = 1
|
||||||
|
print 'append', bit
|
||||||
if bit:
|
if bit:
|
||||||
self.byte |= one << self.bit_of_byte
|
self.byte |= one << self.bit_of_byte
|
||||||
else:
|
else:
|
||||||
|
@ -128,9 +129,9 @@ cdef class HuffmanCodec:
|
||||||
bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
|
bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
def decode(self, BitArray bits):
|
def decode(self, bits):
|
||||||
node = self.nodes.back()
|
node = self.nodes.back()
|
||||||
symbols = []
|
symbols = []
|
||||||
for bit in bits:
|
for bit in bits:
|
||||||
branch = node.right if bit else node.left
|
branch = node.right if bit else node.left
|
||||||
if branch >= 0:
|
if branch >= 0:
|
||||||
|
|
|
@ -16,6 +16,8 @@ from .lexeme cimport check_flag
|
||||||
from .spans import Span
|
from .spans import Span
|
||||||
from .structs cimport UniStr
|
from .structs cimport UniStr
|
||||||
|
|
||||||
|
from .serialize import BitArray
|
||||||
|
|
||||||
from unidecode import unidecode
|
from unidecode import unidecode
|
||||||
# Compiler crashes on memory view coercion without this. Should report bug.
|
# Compiler crashes on memory view coercion without this. Should report bug.
|
||||||
from cython.view cimport array as cvarray
|
from cython.view cimport array as cvarray
|
||||||
|
@ -373,12 +375,55 @@ cdef class Doc:
|
||||||
# Return the merged Python object
|
# Return the merged Python object
|
||||||
return self[start]
|
return self[start]
|
||||||
|
|
||||||
|
def _has_trailing_space(self, int i):
|
||||||
|
cdef int end_idx = self.data[i].idx + self.data[i].lex.length
|
||||||
|
if end_idx >= len(self._string):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return self._string[end_idx] == u' '
|
||||||
|
|
||||||
|
def serialize(self, bits=None):
|
||||||
|
if bits is None:
|
||||||
|
bits = BitArray()
|
||||||
|
codec = self.vocab.codec
|
||||||
|
ids = numpy.zeros(shape=(len(self),), dtype=numpy.uint32)
|
||||||
|
cdef int i
|
||||||
|
for i in range(self.length):
|
||||||
|
ids[i] = self.data[i].lex.id
|
||||||
|
bits = codec.encode(ids, bits=bits)
|
||||||
|
for i in range(self.length):
|
||||||
|
bits.append(self._has_trailing_space(i))
|
||||||
|
return bits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(Vocab vocab, bits):
|
||||||
|
biterator = iter(bits)
|
||||||
|
ids = vocab.codec.decode(biterator)
|
||||||
|
spaces = []
|
||||||
|
for bit in biterator:
|
||||||
|
spaces.append(bit)
|
||||||
|
if len(spaces) == len(ids):
|
||||||
|
break
|
||||||
|
string = u''
|
||||||
|
cdef const LexemeC* lex
|
||||||
|
for id_, space in zip(ids, spaces):
|
||||||
|
lex = vocab.lexemes[id_]
|
||||||
|
string += vocab.strings[lex.orth]
|
||||||
|
if space:
|
||||||
|
string += u' '
|
||||||
|
cdef Doc doc = Doc(vocab, string)
|
||||||
|
cdef int idx = 0
|
||||||
|
for i, id_ in enumerate(ids):
|
||||||
|
doc.push_back(idx, vocab.lexemes[id_])
|
||||||
|
idx += vocab.lexemes[id_].length
|
||||||
|
if spaces[i]:
|
||||||
|
idx += 1
|
||||||
|
return doc
|
||||||
|
|
||||||
# Enhance backwards compatibility by aliasing Doc to Tokens, for now
|
# Enhance backwards compatibility by aliasing Doc to Tokens, for now
|
||||||
Tokens = Doc
|
Tokens = Doc
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Token:
|
cdef class Token:
|
||||||
"""An individual token --- i.e. a word, a punctuation symbol, etc. Created
|
"""An individual token --- i.e. a word, a punctuation symbol, etc. Created
|
||||||
via Doc.__getitem__ and Doc.__iter__.
|
via Doc.__getitem__ and Doc.__iter__.
|
||||||
|
@ -412,6 +457,10 @@ cdef class Token:
|
||||||
self.c, self.i, self.array_len,
|
self.c, self.i, self.array_len,
|
||||||
self._seq)
|
self._seq)
|
||||||
|
|
||||||
|
property lex_id:
|
||||||
|
def __get__(self):
|
||||||
|
return self.c.lex.id
|
||||||
|
|
||||||
property string:
|
property string:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if (self.i+1) == self._seq.length:
|
if (self.i+1) == self._seq.length:
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,3 +11,26 @@ def test_getitem(EN):
|
||||||
assert tokens[-1].orth_ == '.'
|
assert tokens[-1].orth_ == '.'
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
tokens[len(tokens)]
|
tokens[len(tokens)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_trailing_spaces(EN):
|
||||||
|
tokens = EN(u' Give it back! He pleaded. ')
|
||||||
|
assert tokens[0].orth_ == ' '
|
||||||
|
assert not tokens._has_trailing_space(0)
|
||||||
|
assert tokens._has_trailing_space(1)
|
||||||
|
assert tokens._has_trailing_space(2)
|
||||||
|
assert not tokens._has_trailing_space(3)
|
||||||
|
assert tokens._has_trailing_space(4)
|
||||||
|
assert tokens._has_trailing_space(5)
|
||||||
|
assert not tokens._has_trailing_space(6)
|
||||||
|
assert tokens._has_trailing_space(7)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize(EN):
|
||||||
|
tokens = EN(u' Give it back! He pleaded. ')
|
||||||
|
packed = tokens.serialize()
|
||||||
|
new_tokens = Doc.deserialize(EN.vocab, packed)
|
||||||
|
assert tokens.string == new_tokens.string
|
||||||
|
assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens]
|
||||||
|
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
|
||||||
|
assert [tokens._has_trailing_space(t.i) for t in tokens] == [new_tokens._has_trailing_space(t.i) for t in new_tokens]
|
||||||
|
|
|
@ -26,7 +26,8 @@ class Vocab(object):
|
||||||
return self.codec.encode(numpy.array(seq, dtype=numpy.uint32))
|
return self.codec.encode(numpy.array(seq, dtype=numpy.uint32))
|
||||||
|
|
||||||
def unpack(self, packed):
|
def unpack(self, packed):
|
||||||
return [self.symbols[i] for i in self.codec.decode(packed)]
|
ids = self.codec.decode(packed)
|
||||||
|
return [self.symbols[i] for i in ids]
|
||||||
|
|
||||||
|
|
||||||
def py_encode(symb2freq):
|
def py_encode(symb2freq):
|
||||||
|
@ -75,12 +76,9 @@ def test_round_trip():
|
||||||
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
|
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
|
||||||
'the', 'lazy', 'dog', '.']
|
'the', 'lazy', 'dog', '.']
|
||||||
strings = list(vocab.codec.strings)
|
strings = list(vocab.codec.strings)
|
||||||
for i in range(len(vocab.symbols)):
|
|
||||||
print vocab.symbols[i], strings[i]
|
|
||||||
codes = {vocab.symbols[i]: strings[i] for i in range(len(vocab.symbols))}
|
codes = {vocab.symbols[i]: strings[i] for i in range(len(vocab.symbols))}
|
||||||
packed = vocab.pack(message)
|
packed = vocab.pack(message)
|
||||||
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed)
|
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed.as_bytes())
|
||||||
print string
|
|
||||||
for word in message:
|
for word in message:
|
||||||
code = codes[word]
|
code = codes[word]
|
||||||
assert string[:len(code)] == code
|
assert string[:len(code)] == code
|
||||||
|
@ -115,16 +113,10 @@ def test_rosetta():
|
||||||
|
|
||||||
|
|
||||||
def test_vocab(EN):
|
def test_vocab(EN):
|
||||||
probs = numpy.ndarray(shape=(len(EN.vocab), 2), dtype=numpy.float32)
|
codec = EN.vocab.codec
|
||||||
for word in EN.vocab:
|
|
||||||
probs[word.id, 0] = numpy.exp(word.prob)
|
|
||||||
probs[word.id, 1] = word.id
|
|
||||||
probs.sort()
|
|
||||||
probs[:,::-1]
|
|
||||||
codec = HuffmanCodec(probs[:, 0], 0)
|
|
||||||
expected_length = 0
|
expected_length = 0
|
||||||
for i, code in enumerate(codec.strings):
|
for i, code in enumerate(codec.strings):
|
||||||
expected_length += len(code) * probs[i, 0]
|
expected_length += len(code) * numpy.exp(EN.vocab[i].prob)
|
||||||
assert 8 < expected_length < 15
|
assert 8 < expected_length < 15
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue