Improve deserialization of user_data, esp. for Underscore

This commit is contained in:
Matthew Honnibal 2017-10-17 19:29:20 +02:00
parent 374819edf8
commit cdb0c426d8
1 changed file with 24 additions and 4 deletions

View File

@ -9,6 +9,7 @@ import numpy
import numpy.linalg import numpy.linalg
import struct import struct
import dill import dill
import msgpack
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.math cimport sqrt from libc.math cimport sqrt
@ -687,14 +688,22 @@ cdef class Doc:
all annotations. all annotations.
""" """
array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE] array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
# users don't mind getting a list instead of a tuple.
serializers = { serializers = {
'text': lambda: self.text, 'text': lambda: self.text,
'array_head': lambda: array_head, 'array_head': lambda: array_head,
'array_body': lambda: self.to_array(array_head), 'array_body': lambda: self.to_array(array_head),
'sentiment': lambda: self.sentiment, 'sentiment': lambda: self.sentiment,
'tensor': lambda: self.tensor, 'tensor': lambda: self.tensor,
'user_data': lambda: self.user_data
} }
if 'user_data' not in exclude and self.user_data:
user_data_keys, user_data_values = list(zip(*self.user_data.items()))
serializers['user_data_keys'] = lambda: msgpack.dumps(user_data_keys)
serializers['user_data_values'] = lambda: msgpack.dumps(user_data_values)
return util.to_bytes(serializers, exclude) return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, **exclude): def from_bytes(self, bytes_data, **exclude):
@ -711,10 +720,20 @@ cdef class Doc:
'array_body': lambda b: None, 'array_body': lambda b: None,
'sentiment': lambda b: None, 'sentiment': lambda b: None,
'tensor': lambda b: None, 'tensor': lambda b: None,
'user_data': lambda user_data: self.user_data.update(user_data) 'user_data_keys': lambda b: None,
'user_data_values': lambda b: None,
} }
msg = util.from_bytes(bytes_data, deserializers, exclude) msg = util.from_bytes(bytes_data, deserializers, exclude)
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope
# users don't mind getting a list instead of a tuple.
if 'user_data' not in exclude and 'user_data_keys' in msg:
user_data_keys = msgpack.loads(msg['user_data_keys'], use_list=False)
user_data_values = msgpack.loads(msg['user_data_values'])
for key, value in zip(user_data_keys, user_data_values):
self.user_data[key] = value
cdef attr_t[:, :] attrs cdef attr_t[:, :] attrs
cdef int i, start, end, has_space cdef int i, start, end, has_space
@ -919,12 +938,13 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
def pickle_doc(doc): def pickle_doc(doc):
bytes_data = doc.to_bytes(vocab=False) bytes_data = doc.to_bytes(vocab=False, user_data=False)
return (unpickle_doc, (doc.vocab, doc.user_data, bytes_data)) return (unpickle_doc, (doc.vocab, doc.user_data, bytes_data))
def unpickle_doc(vocab, user_data, bytes_data): def unpickle_doc(vocab, user_data, bytes_data):
doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data) doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data,
exclude='user_data')
return doc return doc