diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 158cb9220..7c67df9c3 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -9,6 +9,7 @@ import numpy import numpy.linalg import struct import dill +import msgpack from libc.string cimport memcpy, memset from libc.math cimport sqrt @@ -687,14 +688,22 @@ cdef class Doc: all annotations. """ array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE] + # Msgpack doesn't distinguish between lists and tuples, which is + # vexing for user data. As a best guess, we *know* that within + # keys, we must have tuples. In values we just have to hope + # users don't mind getting a list instead of a tuple. serializers = { 'text': lambda: self.text, 'array_head': lambda: array_head, 'array_body': lambda: self.to_array(array_head), 'sentiment': lambda: self.sentiment, 'tensor': lambda: self.tensor, - 'user_data': lambda: self.user_data } + if 'user_data' not in exclude and self.user_data: + user_data_keys, user_data_values = list(zip(*self.user_data.items())) + serializers['user_data_keys'] = lambda: msgpack.dumps(user_data_keys) + serializers['user_data_values'] = lambda: msgpack.dumps(user_data_values) + return util.to_bytes(serializers, exclude) def from_bytes(self, bytes_data, **exclude): @@ -711,10 +720,20 @@ cdef class Doc: 'array_body': lambda b: None, 'sentiment': lambda b: None, 'tensor': lambda b: None, - 'user_data': lambda user_data: self.user_data.update(user_data) + 'user_data_keys': lambda b: None, + 'user_data_values': lambda b: None, } msg = util.from_bytes(bytes_data, deserializers, exclude) + # Msgpack doesn't distinguish between lists and tuples, which is + # vexing for user data. As a best guess, we *know* that within + # keys, we must have tuples. In values we just have to hope + # users don't mind getting a list instead of a tuple. + if 'user_data' not in exclude and 'user_data_keys' in msg: + user_data_keys = msgpack.loads(msg['user_data_keys'], use_list=False) + user_data_values = msgpack.loads(msg['user_data_values']) + for key, value in zip(user_data_keys, user_data_values): + self.user_data[key] = value cdef attr_t[:, :] attrs cdef int i, start, end, has_space @@ -919,12 +938,13 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1: def pickle_doc(doc): - bytes_data = doc.to_bytes(vocab=False) + bytes_data = doc.to_bytes(vocab=False, user_data=False) return (unpickle_doc, (doc.vocab, doc.user_data, bytes_data)) def unpickle_doc(vocab, user_data, bytes_data): - doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data) + doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, + exclude='user_data') return doc