mirror of https://github.com/explosion/spaCy.git
* Round-trip for serialization finally working. Needs a lot of optimization.
This commit is contained in:
parent
edd371246c
commit
5b0a7190c9
|
@ -51,13 +51,13 @@ cdef class BitArray:
|
|||
start_byte = self.i // 8
|
||||
if (self.i % 8) != 0:
|
||||
for i in range(self.i % 8):
|
||||
yield (self.data[start_byte] & (one << i))
|
||||
yield 1 if (self.data[start_byte] & (one << i)) else 0
|
||||
start_byte += 1
|
||||
for byte in self.data[start_byte:]:
|
||||
for i in range(8):
|
||||
yield byte & (one << i)
|
||||
yield 1 if byte & (one << i) else 0
|
||||
for i in range(self.bit_of_byte):
|
||||
yield self.byte & (one << i)
|
||||
yield 1 if self.byte & (one << i) else 0
|
||||
|
||||
def as_bytes(self):
|
||||
if self.bit_of_byte != 0:
|
||||
|
@ -67,6 +67,7 @@ cdef class BitArray:
|
|||
|
||||
def append(self, bint bit):
|
||||
cdef uint64_t one = 1
|
||||
print 'append', bit
|
||||
if bit:
|
||||
self.byte |= one << self.bit_of_byte
|
||||
else:
|
||||
|
@ -128,9 +129,9 @@ cdef class HuffmanCodec:
|
|||
bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
|
||||
return bits
|
||||
|
||||
def decode(self, BitArray bits):
|
||||
def decode(self, bits):
|
||||
node = self.nodes.back()
|
||||
symbols = []
|
||||
symbols = []
|
||||
for bit in bits:
|
||||
branch = node.right if bit else node.left
|
||||
if branch >= 0:
|
||||
|
|
|
@ -16,6 +16,8 @@ from .lexeme cimport check_flag
|
|||
from .spans import Span
|
||||
from .structs cimport UniStr
|
||||
|
||||
from .serialize import BitArray
|
||||
|
||||
from unidecode import unidecode
|
||||
# Compiler crashes on memory view coercion without this. Should report bug.
|
||||
from cython.view cimport array as cvarray
|
||||
|
@ -373,12 +375,55 @@ cdef class Doc:
|
|||
# Return the merged Python object
|
||||
return self[start]
|
||||
|
||||
def _has_trailing_space(self, int i):
|
||||
cdef int end_idx = self.data[i].idx + self.data[i].lex.length
|
||||
if end_idx >= len(self._string):
|
||||
return False
|
||||
else:
|
||||
return self._string[end_idx] == u' '
|
||||
|
||||
def serialize(self, bits=None):
|
||||
if bits is None:
|
||||
bits = BitArray()
|
||||
codec = self.vocab.codec
|
||||
ids = numpy.zeros(shape=(len(self),), dtype=numpy.uint32)
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
ids[i] = self.data[i].lex.id
|
||||
bits = codec.encode(ids, bits=bits)
|
||||
for i in range(self.length):
|
||||
bits.append(self._has_trailing_space(i))
|
||||
return bits
|
||||
|
||||
@staticmethod
|
||||
def deserialize(Vocab vocab, bits):
|
||||
biterator = iter(bits)
|
||||
ids = vocab.codec.decode(biterator)
|
||||
spaces = []
|
||||
for bit in biterator:
|
||||
spaces.append(bit)
|
||||
if len(spaces) == len(ids):
|
||||
break
|
||||
string = u''
|
||||
cdef const LexemeC* lex
|
||||
for id_, space in zip(ids, spaces):
|
||||
lex = vocab.lexemes[id_]
|
||||
string += vocab.strings[lex.orth]
|
||||
if space:
|
||||
string += u' '
|
||||
cdef Doc doc = Doc(vocab, string)
|
||||
cdef int idx = 0
|
||||
for i, id_ in enumerate(ids):
|
||||
doc.push_back(idx, vocab.lexemes[id_])
|
||||
idx += vocab.lexemes[id_].length
|
||||
if spaces[i]:
|
||||
idx += 1
|
||||
return doc
|
||||
|
||||
# Enhance backwards compatibility by aliasing Doc to Tokens, for now
|
||||
Tokens = Doc
|
||||
|
||||
|
||||
|
||||
cdef class Token:
|
||||
"""An individual token --- i.e. a word, a punctuation symbol, etc. Created
|
||||
via Doc.__getitem__ and Doc.__iter__.
|
||||
|
@ -412,6 +457,10 @@ cdef class Token:
|
|||
self.c, self.i, self.array_len,
|
||||
self._seq)
|
||||
|
||||
property lex_id:
|
||||
def __get__(self):
|
||||
return self.c.lex.id
|
||||
|
||||
property string:
|
||||
def __get__(self):
|
||||
if (self.i+1) == self._seq.length:
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.tokens import Doc
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
@ -9,3 +11,26 @@ def test_getitem(EN):
|
|||
assert tokens[-1].orth_ == '.'
|
||||
with pytest.raises(IndexError):
|
||||
tokens[len(tokens)]
|
||||
|
||||
|
||||
def test_trailing_spaces(EN):
|
||||
tokens = EN(u' Give it back! He pleaded. ')
|
||||
assert tokens[0].orth_ == ' '
|
||||
assert not tokens._has_trailing_space(0)
|
||||
assert tokens._has_trailing_space(1)
|
||||
assert tokens._has_trailing_space(2)
|
||||
assert not tokens._has_trailing_space(3)
|
||||
assert tokens._has_trailing_space(4)
|
||||
assert tokens._has_trailing_space(5)
|
||||
assert not tokens._has_trailing_space(6)
|
||||
assert tokens._has_trailing_space(7)
|
||||
|
||||
|
||||
def test_serialize(EN):
|
||||
tokens = EN(u' Give it back! He pleaded. ')
|
||||
packed = tokens.serialize()
|
||||
new_tokens = Doc.deserialize(EN.vocab, packed)
|
||||
assert tokens.string == new_tokens.string
|
||||
assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens]
|
||||
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
|
||||
assert [tokens._has_trailing_space(t.i) for t in tokens] == [new_tokens._has_trailing_space(t.i) for t in new_tokens]
|
||||
|
|
|
@ -26,7 +26,8 @@ class Vocab(object):
|
|||
return self.codec.encode(numpy.array(seq, dtype=numpy.uint32))
|
||||
|
||||
def unpack(self, packed):
|
||||
return [self.symbols[i] for i in self.codec.decode(packed)]
|
||||
ids = self.codec.decode(packed)
|
||||
return [self.symbols[i] for i in ids]
|
||||
|
||||
|
||||
def py_encode(symb2freq):
|
||||
|
@ -75,12 +76,9 @@ def test_round_trip():
|
|||
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
|
||||
'the', 'lazy', 'dog', '.']
|
||||
strings = list(vocab.codec.strings)
|
||||
for i in range(len(vocab.symbols)):
|
||||
print vocab.symbols[i], strings[i]
|
||||
codes = {vocab.symbols[i]: strings[i] for i in range(len(vocab.symbols))}
|
||||
packed = vocab.pack(message)
|
||||
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed)
|
||||
print string
|
||||
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed.as_bytes())
|
||||
for word in message:
|
||||
code = codes[word]
|
||||
assert string[:len(code)] == code
|
||||
|
@ -115,16 +113,10 @@ def test_rosetta():
|
|||
|
||||
|
||||
def test_vocab(EN):
|
||||
probs = numpy.ndarray(shape=(len(EN.vocab), 2), dtype=numpy.float32)
|
||||
for word in EN.vocab:
|
||||
probs[word.id, 0] = numpy.exp(word.prob)
|
||||
probs[word.id, 1] = word.id
|
||||
probs.sort()
|
||||
probs[:,::-1]
|
||||
codec = HuffmanCodec(probs[:, 0], 0)
|
||||
codec = EN.vocab.codec
|
||||
expected_length = 0
|
||||
for i, code in enumerate(codec.strings):
|
||||
expected_length += len(code) * probs[i, 0]
|
||||
expected_length += len(code) * numpy.exp(EN.vocab[i].prob)
|
||||
assert 8 < expected_length < 15
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue