Fix for serialization

This commit is contained in:
Matthew Honnibal 2017-05-29 13:53:32 +02:00
parent a1960c2d09
commit 04c32aa091
1 changed files with 5 additions and 6 deletions

View File

@ -437,10 +437,10 @@ def model_to_bytes(model):
i = 0 i = 0
for layer in queue: for layer in queue:
if hasattr(layer, '_mem'): if hasattr(layer, '_mem'):
if layer._mem.weights.size: if isinstance(layer._mem.weights, numpy.ndarray):
weights.append(layer._mem.weights) weights.append(layer._mem.weights)
else: else:
weights.append(None) weights.append(layer._mem.weights.get())
metas.append(tuple(layer._mem._offsets)) metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None)) dims.append(getattr(layer, '_dims', None))
i += 1 i += 1
@ -461,7 +461,6 @@ def model_from_bytes(model, bytes_data):
for layer in queue: for layer in queue:
if hasattr(layer, '_mem'): if hasattr(layer, '_mem'):
params = weights[i] params = weights[i]
if params is not None:
flat_mem = layer._mem._mem.ravel() flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel() flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params flat_mem[:flat_params.size] = flat_params