Fixes to hacky vocab pickling

This commit is contained in:
Matthew Honnibal 2017-03-07 20:58:55 +01:00
parent d814892805
commit a89c3500f6
2 changed files with 26 additions and 5 deletions

View File

@ -57,7 +57,7 @@ cdef class StringCFile:
self.size = len(data)
self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
for i in range(len(data)):
self.data[i] = data
self.data[i] = data[i]
def close(self):
self.is_open = False
@ -69,13 +69,12 @@ cdef class StringCFile:
memcpy(dest, self.data, elem_size * number)
self.data += elem_size * number
cdef int write_from(self, void* src, size_t number, size_t elem_size) except -1:
cdef int write_from(self, void* src, size_t elem_size, size_t number) except -1:
write_size = number * elem_size
if (self.size + write_size) >= self._capacity:
self._capacity = (self.size + write_size) * 2
self.data = <unsigned char*>self.mem.realloc(self.data, self._capacity)
memcpy(self.data, src, elem_size * number)
self.data += write_size
memcpy(&self.data[self.size], src, elem_size * number)
self.size += write_size
cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:

View File

@ -1,9 +1,12 @@
from __future__ import unicode_literals
import io
import pickle
import pytest
import dill as pickle
from ..strings import StringStore
from ..vocab import Vocab
from ..attrs import NORM
def test_pickle_string_store():
@ -14,4 +17,23 @@ def test_pickle_string_store():
unpickled = pickle.loads(bdata)
assert unpickled['hello'] == hello
assert unpickled['bye'] == bye
assert len(sstore) == len(unpickled)
def test_pickle_vocab():
vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]})
dog = vocab[u'dog']
cat = vocab[u'cat']
assert dog.norm_ == 'do'
assert cat.norm_ == 'ca'
bdata = pickle.dumps(vocab)
unpickled = pickle.loads(bdata)
assert unpickled[u'dog'].orth == dog.orth
assert unpickled[u'cat'].orth == cat.orth
assert unpickled[u'dog'].norm == dog.norm
assert unpickled[u'cat'].norm == cat.norm
dog_ = unpickled[u'dog']
cat_ = unpickled[u'cat']
assert dog_.norm != cat_.norm