From 2a061e27772306536be428f15b219d7ad3944b6d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 29 May 2017 17:52:08 -0500 Subject: [PATCH] Fix serialisation, for reals this time --- spacy/util.py | 51 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/spacy/util.py b/spacy/util.py index a261029d5..df66b59a8 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -12,6 +12,7 @@ import textwrap import random import numpy import io +import dill import msgpack import msgpack_numpy @@ -422,45 +423,53 @@ def from_bytes(bytes_data, setters, exclude): return msg +# This stuff really belongs in thinc -- but I expect +# to refactor how all this works in thinc anyway. +# What a mess! def model_to_bytes(model): weights = [] - metas = [] - dims = [] queue = [model] i = 0 for layer in queue: if hasattr(layer, '_mem'): - if isinstance(layer._mem.weights, numpy.ndarray): - weights.append(layer._mem.weights) - else: - weights.append(layer._mem.weights.get()) - metas.append(layer._mem._offsets) - dims.append(getattr(layer, '_dims', None)) + weights.append({'dims': dict(getattr(layer, '_dims', {})), 'params': []}) + if hasattr(layer, 'seed'): + weights[-1]['seed'] = layer.seed + + for (id_, name), (start, row, shape) in layer._mem._offsets.items(): + if row == 1: + continue + param = layer._mem.get((id_, name)) + if not isinstance(layer._mem.weights, numpy.ndarray): + param = param.get() + weights[-1]['params'].append( + { + 'name': name, + 'offset': start, + 'shape': shape, + 'value': param, + } + ) i += 1 if hasattr(layer, '_layers'): queue.extend(layer._layers) - data = {'metas': ujson.dumps(metas), 'weights': weights, 'dims': ujson.dumps(dims)} - return msgpack.dumps(data) + return msgpack.dumps({'weights': weights}) def model_from_bytes(model, bytes_data): data = msgpack.loads(bytes_data) weights = data['weights'] - metas = ujson.loads(data['metas']) - dims = ujson.loads(data['dims']) queue = [model] i = 0 for layer in queue: if hasattr(layer, '_mem'): - params = weights[i] - layer._mem._get_blob(params.size) - layer._mem._i -= params.size - flat_mem = layer._mem._mem.ravel() - flat_params = params.ravel() - flat_mem[:flat_params.size] = flat_params - layer._mem._offsets.update(metas[i]) - if hasattr(layer, '_dims'): - layer._dims.update(dims[i]) + if 'seed' in weights[i]: + layer.seed = weights[i]['seed'] + for dim, value in weights[i]['dims'].items(): + setattr(layer, dim, value) + for param in weights[i]['params']: + dest = getattr(layer, param['name']) + dest[:] = param['value'] i += 1 if hasattr(layer, '_layers'): queue.extend(layer._layers)