Fix for serialization

This commit is contained in:
Matthew Honnibal 2017-05-29 13:53:32 +02:00
parent a1960c2d09
commit 04c32aa091
1 changed files with 5 additions and 6 deletions

View File

@ -437,10 +437,10 @@ def model_to_bytes(model):
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
if layer._mem.weights.size:
if isinstance(layer._mem.weights, numpy.ndarray):
weights.append(layer._mem.weights)
else:
weights.append(None)
weights.append(layer._mem.weights.get())
metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None))
i += 1
@ -461,10 +461,9 @@ def model_from_bytes(model, bytes_data):
for layer in queue:
if hasattr(layer, '_mem'):
params = weights[i]
if params is not None:
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])