Fix initialization of vectors, to address serialization problem

This commit is contained in:
Matthew Honnibal 2017-10-20 13:59:24 +02:00
parent 61bc203f3f
commit 92ac9316b5
2 changed files with 9 additions and 13 deletions

View File

@ -32,22 +32,20 @@ cdef class Vectors:
cdef public object keys cdef public object keys
cdef public int i cdef public int i
def __init__(self, strings, data_or_width=0): def __init__(self, strings, data=None, width=0):
if isinstance(strings, StringStore): if isinstance(strings, StringStore):
self.strings = strings self.strings = strings
else: else:
self.strings = StringStore() self.strings = StringStore()
for string in strings: for string in strings:
self.strings.add(string) self.strings.add(string)
if isinstance(data_or_width, int): if data is not None:
self.data = data = numpy.zeros((len(strings), data_or_width), self.data = numpy.asarray(data, dtype='f')
dtype='f')
else: else:
data = data_or_width self.data = numpy.zeros((len(self.strings), width), dtype='f')
self.i = 0 self.i = 0
self.data = data
self.key2row = {} self.key2row = {}
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))

View File

@ -62,12 +62,10 @@ cdef class Vocab:
if strings: if strings:
for string in strings: for string in strings:
_ = self[string] _ = self[string]
for name in tag_map.keys():
if name:
self.strings.add(name)
self.lex_attr_getters = lex_attr_getters self.lex_attr_getters = lex_attr_getters
print("Create morphology", list(self.strings), tag_map)
self.morphology = Morphology(self.strings, tag_map, lemmatizer) self.morphology = Morphology(self.strings, tag_map, lemmatizer)
self.vectors = Vectors(self.strings) self.vectors = Vectors(self.strings, width=0)
property lang: property lang:
def __get__(self): def __get__(self):
@ -338,7 +336,7 @@ cdef class Vocab:
if self.vectors is None: if self.vectors is None:
return None return None
else: else:
return self.vectors.to_bytes(exclude='strings.json') return self.vectors.to_bytes()
getters = OrderedDict(( getters = OrderedDict((
('strings', lambda: self.strings.to_bytes()), ('strings', lambda: self.strings.to_bytes()),
@ -358,7 +356,7 @@ cdef class Vocab:
if self.vectors is None: if self.vectors is None:
return None return None
else: else:
return self.vectors.from_bytes(b, exclude='strings') return self.vectors.from_bytes(b)
setters = OrderedDict(( setters = OrderedDict((
('strings', lambda b: self.strings.from_bytes(b)), ('strings', lambda b: self.strings.from_bytes(b)),
('lexemes', lambda b: self.lexemes_from_bytes(b)), ('lexemes', lambda b: self.lexemes_from_bytes(b)),