mirror of https://github.com/explosion/spaCy.git
Move weight serialization to Thinc
This commit is contained in:
parent
67ade63fc4
commit
ae8010b526
|
@@ -3,8 +3,8 @@ pathlib
|
||||||
numpy>=1.7
|
numpy>=1.7
|
||||||
cymem>=1.30,<1.32
|
cymem>=1.30,<1.32
|
||||||
preshed>=1.0.0,<2.0.0
|
preshed>=1.0.0,<2.0.0
|
||||||
thinc>=6.6.0,<6.7.0
|
thinc>=6.7.0,<6.8.0
|
||||||
murmurhash>=0.26,<0.27
|
murmurhash>=0.28,<0.29
|
||||||
plac<1.0.0,>=0.9.6
|
plac<1.0.0,>=0.9.6
|
||||||
six
|
six
|
||||||
ujson>=1.35
|
ujson>=1.35
|
||||||
|
|
4
setup.py
4
setup.py
|
@@ -188,10 +188,10 @@ def setup_package():
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'numpy>=1.7',
|
'numpy>=1.7',
|
||||||
'murmurhash>=0.26,<0.27',
|
'murmurhash>=0.28,<0.29',
|
||||||
'cymem>=1.30,<1.32',
|
'cymem>=1.30,<1.32',
|
||||||
'preshed>=1.0.0,<2.0.0',
|
'preshed>=1.0.0,<2.0.0',
|
||||||
'thinc>=6.6.0,<6.7.0',
|
'thinc>=6.7.0,<6.8.0',
|
||||||
'plac<1.0.0,>=0.9.6',
|
'plac<1.0.0,>=0.9.6',
|
||||||
'pip>=9.0.0,<10.0.0',
|
'pip>=9.0.0,<10.0.0',
|
||||||
'six',
|
'six',
|
||||||
|
|
|
@@ -169,7 +169,7 @@ class TokenVectorEncoder(object):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model()
|
self.model = self.Model()
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('model', lambda b: util.model_from_bytes(self.model, b)),
|
('model', lambda b: self.model.from_bytes(b)),
|
||||||
('vocab', lambda b: self.vocab.from_bytes(b))
|
('vocab', lambda b: self.vocab.from_bytes(b))
|
||||||
))
|
))
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
|
@@ -186,7 +186,7 @@ class TokenVectorEncoder(object):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model()
|
self.model = self.Model()
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('model', lambda p: util.model_from_bytes(self.model, p.open('rb').read())),
|
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
|
||||||
('vocab', lambda p: self.vocab.from_disk(p))
|
('vocab', lambda p: self.vocab.from_disk(p))
|
||||||
))
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
@@ -307,7 +307,7 @@ class NeuralTagger(object):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
token_vector_width = util.env_opt('token_vector_width', 128)
|
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||||
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
||||||
util.model_from_bytes(self.model, b)
|
self.model.from_bytes(b)
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('vocab', lambda b: self.vocab.from_bytes(b)),
|
('vocab', lambda b: self.vocab.from_bytes(b)),
|
||||||
('model', lambda b: load_model(b)),
|
('model', lambda b: load_model(b)),
|
||||||
|
@@ -324,7 +324,7 @@ class NeuralTagger(object):
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
deserialize = {
|
deserialize = {
|
||||||
'model': lambda p: util.model_from_bytes(self.model, p.open('rb').read()),
|
'model': lambda p: self.model.from_bytes(p.open('rb').read()),
|
||||||
'vocab': lambda p: self.vocab.from_disk(p)
|
'vocab': lambda p: self.vocab.from_disk(p)
|
||||||
}
|
}
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
|
@@ -660,10 +660,10 @@ cdef class Parser:
|
||||||
cfg = {}
|
cfg = {}
|
||||||
with (path / 'lower_model').open('rb') as file_:
|
with (path / 'lower_model').open('rb') as file_:
|
||||||
bytes_data = file_.read()
|
bytes_data = file_.read()
|
||||||
util.model_from_bytes(self.model[0], bytes_data)
|
self.model[0].from_bytes(bytes_data)
|
||||||
with (path / 'upper_model').open('rb') as file_:
|
with (path / 'upper_model').open('rb') as file_:
|
||||||
bytes_data = file_.read()
|
bytes_data = file_.read()
|
||||||
util.model_from_bytes(self.model[1], bytes_data)
|
self.model[1].from_bytes(bytes_data)
|
||||||
self.cfg.update(cfg)
|
self.cfg.update(cfg)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@@ -691,8 +691,8 @@ cdef class Parser:
|
||||||
self.model, cfg = self.Model(self.moves.n_moves)
|
self.model, cfg = self.Model(self.moves.n_moves)
|
||||||
else:
|
else:
|
||||||
cfg = {}
|
cfg = {}
|
||||||
util.model_from_bytes(self.model[0], msg['lower_model'])
|
self.model[0].from_bytes(msg['lower_model'])
|
||||||
util.model_from_bytes(self.model[1], msg['upper_model'])
|
self.model[1].from_bytes(msg['upper_model'])
|
||||||
self.cfg.update(cfg)
|
self.cfg.update(cfg)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@@ -452,60 +452,6 @@ def from_disk(path, readers, exclude):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
# This stuff really belongs in thinc -- but I expect
|
|
||||||
# to refactor how all this works in thinc anyway.
|
|
||||||
# What a mess!
|
|
||||||
def model_to_bytes(model):
|
|
||||||
weights = []
|
|
||||||
queue = [model]
|
|
||||||
i = 0
|
|
||||||
for layer in queue:
|
|
||||||
if hasattr(layer, '_mem'):
|
|
||||||
weights.append({
|
|
||||||
'dims': normalize_string_keys(getattr(layer, '_dims', {})),
|
|
||||||
'params': []})
|
|
||||||
if hasattr(layer, 'seed'):
|
|
||||||
weights[-1]['seed'] = layer.seed
|
|
||||||
|
|
||||||
for (id_, name), (start, row, shape) in layer._mem._offsets.items():
|
|
||||||
if row == 1:
|
|
||||||
continue
|
|
||||||
param = layer._mem.get((id_, name))
|
|
||||||
if not isinstance(layer._mem.weights, numpy.ndarray):
|
|
||||||
param = param.get()
|
|
||||||
weights[-1]['params'].append(
|
|
||||||
{
|
|
||||||
'name': name,
|
|
||||||
'offset': start,
|
|
||||||
'shape': shape,
|
|
||||||
'value': param,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
i += 1
|
|
||||||
if hasattr(layer, '_layers'):
|
|
||||||
queue.extend(layer._layers)
|
|
||||||
return msgpack.dumps({'weights': weights})
|
|
||||||
|
|
||||||
|
|
||||||
def model_from_bytes(model, bytes_data):
|
|
||||||
data = msgpack.loads(bytes_data)
|
|
||||||
weights = data['weights']
|
|
||||||
queue = [model]
|
|
||||||
i = 0
|
|
||||||
for layer in queue:
|
|
||||||
if hasattr(layer, '_mem'):
|
|
||||||
if 'seed' in weights[i]:
|
|
||||||
layer.seed = weights[i]['seed']
|
|
||||||
for dim, value in weights[i]['dims'].items():
|
|
||||||
setattr(layer, dim, value)
|
|
||||||
for param in weights[i]['params']:
|
|
||||||
dest = getattr(layer, param['name'])
|
|
||||||
copy_array(dest, param['value'])
|
|
||||||
i += 1
|
|
||||||
if hasattr(layer, '_layers'):
|
|
||||||
queue.extend(layer._layers)
|
|
||||||
|
|
||||||
|
|
||||||
def print_table(data, title=None):
|
def print_table(data, title=None):
|
||||||
"""Print data in table format.
|
"""Print data in table format.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue