mirror of https://github.com/explosion/spaCy.git
Fix for serialization
This commit is contained in:
parent
a1960c2d09
commit
04c32aa091
|
@ -437,10 +437,10 @@ def model_to_bytes(model):
|
||||||
i = 0
|
i = 0
|
||||||
for layer in queue:
|
for layer in queue:
|
||||||
if hasattr(layer, '_mem'):
|
if hasattr(layer, '_mem'):
|
||||||
if layer._mem.weights.size:
|
if isinstance(layer._mem.weights, numpy.ndarray):
|
||||||
weights.append(layer._mem.weights)
|
weights.append(layer._mem.weights)
|
||||||
else:
|
else:
|
||||||
weights.append(None)
|
weights.append(layer._mem.weights.get())
|
||||||
metas.append(tuple(layer._mem._offsets))
|
metas.append(tuple(layer._mem._offsets))
|
||||||
dims.append(getattr(layer, '_dims', None))
|
dims.append(getattr(layer, '_dims', None))
|
||||||
i += 1
|
i += 1
|
||||||
|
@ -461,10 +461,9 @@ def model_from_bytes(model, bytes_data):
|
||||||
for layer in queue:
|
for layer in queue:
|
||||||
if hasattr(layer, '_mem'):
|
if hasattr(layer, '_mem'):
|
||||||
params = weights[i]
|
params = weights[i]
|
||||||
if params is not None:
|
flat_mem = layer._mem._mem.ravel()
|
||||||
flat_mem = layer._mem._mem.ravel()
|
flat_params = params.ravel()
|
||||||
flat_params = params.ravel()
|
flat_mem[:flat_params.size] = flat_params
|
||||||
flat_mem[:flat_params.size] = flat_params
|
|
||||||
layer._mem._offsets.update(metas[i])
|
layer._mem._offsets.update(metas[i])
|
||||||
if hasattr(layer, '_dims'):
|
if hasattr(layer, '_dims'):
|
||||||
layer._dims.update(dims[i])
|
layer._dims.update(dims[i])
|
||||||
|
|
Loading…
Reference in New Issue