Improve deserialization of user_data, esp. for Underscore

This commit is contained in:
Matthew Honnibal 2017-10-17 19:29:20 +02:00
parent 374819edf8
commit cdb0c426d8
1 changed file with 24 additions and 4 deletions

View File

@ -9,6 +9,7 @@ import numpy
import numpy.linalg
import struct
import dill
import msgpack
from libc.string cimport memcpy, memset
from libc.math cimport sqrt
@ -687,14 +688,22 @@ cdef class Doc:
all annotations.
"""
array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
# users don't mind getting a list instead of a tuple.
serializers = {
'text': lambda: self.text,
'array_head': lambda: array_head,
'array_body': lambda: self.to_array(array_head),
'sentiment': lambda: self.sentiment,
'tensor': lambda: self.tensor,
'user_data': lambda: self.user_data
}
if 'user_data' not in exclude and self.user_data:
user_data_keys, user_data_values = list(zip(*self.user_data.items()))
serializers['user_data_keys'] = lambda: msgpack.dumps(user_data_keys)
serializers['user_data_values'] = lambda: msgpack.dumps(user_data_values)
return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, **exclude):
@ -711,10 +720,20 @@ cdef class Doc:
'array_body': lambda b: None,
'sentiment': lambda b: None,
'tensor': lambda b: None,
'user_data': lambda user_data: self.user_data.update(user_data)
'user_data_keys': lambda b: None,
'user_data_values': lambda b: None,
}
msg = util.from_bytes(bytes_data, deserializers, exclude)
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
# users don't mind getting a list instead of a tuple.
if 'user_data' not in exclude and 'user_data_keys' in msg:
user_data_keys = msgpack.loads(msg['user_data_keys'], use_list=False)
user_data_values = msgpack.loads(msg['user_data_values'])
for key, value in zip(user_data_keys, user_data_values):
self.user_data[key] = value
cdef attr_t[:, :] attrs
cdef int i, start, end, has_space
@ -919,12 +938,13 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
def pickle_doc(doc):
    """Reduce `doc` to a `(callable, args)` pair for pickling.

    doc: The object to reduce. It must provide `to_bytes()`, `vocab` and
        `user_data`.
    RETURNS (tuple): `unpickle_doc` together with the arguments it should
        be called with: `(doc.vocab, doc.user_data, bytes_data)`.
    """
    # Serialize exactly once. The vocab and the user data are handed over as
    # separate elements of the argument tuple, so both are kept out of the
    # byte payload instead of being stored twice.
    bytes_data = doc.to_bytes(vocab=False, user_data=False)
    return (unpickle_doc, (doc.vocab, doc.user_data, bytes_data))
def unpickle_doc(vocab, user_data, bytes_data):
    """Rebuild a Doc from the arguments produced by `pickle_doc`.

    vocab: The vocabulary passed to the `Doc` constructor.
    user_data: The user data passed to the `Doc` constructor as-is.
    bytes_data (bytes): The payload written by `Doc.to_bytes()`.
    RETURNS: Whatever `Doc.from_bytes()` returns for the new object.
    """
    # `from_bytes` gathers its keyword arguments into the `exclude` mapping
    # and skips a field when the field's *name* is one of the keys. The field
    # therefore has to be named as the keyword itself (`user_data=False`);
    # writing `exclude='user_data'` would only create an unrelated 'exclude'
    # key and exclude nothing. User data was already supplied to the
    # constructor, so it must not be overwritten from the payload.
    return Doc(vocab, user_data=user_data).from_bytes(bytes_data,
                                                      user_data=False)