mirror of https://github.com/explosion/spaCy.git
More serialization fixes. Still broken
This commit is contained in:
parent 9c9ee24411
commit 6522ea6c8b
@@ -166,6 +166,8 @@ class TokenVectorEncoder(object):
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
+        if self.model is True:
+            self.model = self.Model()
         deserialize = OrderedDict((
             ('model', lambda b: util.model_from_bytes(self.model, b)),
             ('vocab', lambda b: self.vocab.from_bytes(b))
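The two added lines follow a lazy-construction pattern: a component can be created with `model=True` as a placeholder, and the concrete model is only built once `from_bytes` actually needs something to load weights into. A minimal sketch of that pattern, outside the diff, with a hypothetical `LazyComponent` standing in for `TokenVectorEncoder` (not spaCy's actual API):

class LazyComponent(object):
    """Hypothetical stand-in; illustrates the placeholder-then-build idea."""

    def __init__(self, model=True):
        self.model = model              # True means "not built yet"

    @classmethod
    def Model(cls):
        return {'weights': None}        # stand-in for a real thinc model

    def from_bytes(self, bytes_data, **exclude):
        if self.model is True:
            self.model = self.Model()   # build on demand, then load into it
        if 'model' not in exclude:
            self.model['weights'] = bytes_data
        return self


comp = LazyComponent()                  # no model built at construction time
comp.from_bytes(b'\x00\x01')            # model created here, then populated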
@@ -278,9 +280,14 @@ class NeuralTagger(object):
         vocab.morphology = Morphology(vocab.strings, new_tag_map,
                                       vocab.morphology.lemmatizer)
         token_vector_width = pipeline[0].model.nO
-        self.model = with_flatten(
-            chain(Maxout(token_vector_width, token_vector_width),
-                  Softmax(self.vocab.morphology.n_tags, token_vector_width)))
+        if self.model is True:
+            self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
+
+    @classmethod
+    def Model(cls, n_tags, token_vector_width):
+        return with_flatten(
+            chain(Maxout(token_vector_width, token_vector_width),
+                  Softmax(n_tags, token_vector_width)))
 
     def use_params(self, params):
         with self.model.use_params(params):
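This hunk hoists the network definition out of training setup and into a `Model` classmethod, so that `begin_training` and `from_bytes` can build identically shaped networks before any weights are loaded. In the diff the factory returns `with_flatten(chain(Maxout(...), Softmax(...)))`; the sketch below uses a plain dict as a stand-in for the network, and the method signatures are assumptions for illustration only:

class TaggerSketch(object):
    def __init__(self, n_tags, model=True):
        self.n_tags = n_tags
        self.model = model

    @classmethod
    def Model(cls, n_tags, token_vector_width):
        # Stand-in for with_flatten(chain(Maxout(...), Softmax(...))).
        return {'n_tags': n_tags, 'width': token_vector_width, 'weights': None}

    def begin_training(self, token_vector_width):
        if self.model is True:
            self.model = self.Model(self.n_tags, token_vector_width)

    def from_bytes(self, bytes_data, token_vector_width=128):
        if self.model is True:
            # Same factory as training, so the shapes match the saved weights.
            self.model = self.Model(self.n_tags, token_vector_width)
        self.model['weights'] = bytes_data
        return self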
@@ -294,11 +301,16 @@ class NeuralTagger(object):
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
+        def load_model(b):
+            if self.model is True:
+                token_vector_width = util.env_opt('token_vector_width', 128)
+                self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
+            util.model_from_bytes(self.model, b)
         deserialize = {
-            'model': lambda b: util.model_from_bytes(self.model, b),
-            'vocab': lambda b: self.vocab.from_bytes(b)
+            'vocab': lambda b: self.vocab.from_bytes(b),
+            'model': lambda b: load_model(b)
         }
-        util.from_bytes(deserialize, exclude)
+        util.from_bytes(bytes_data, deserialize, exclude)
         return self
 
     def to_disk(self, path, **exclude):
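Two things change here: the 'model' entry now goes through a `load_model` closure that can build the network lazily (falling back to the `token_vector_width` environment option when there is no pipeline to read it from), and the call is corrected to `util.from_bytes(bytes_data, deserialize, exclude)` so the payload is actually passed in. How `util.to_bytes`/`util.from_bytes` pack the sections is not shown in this diff; the helpers below are only an assumed illustration of the name-to-callback dispatch idea, not spaCy's real implementation:

import json


def to_bytes(getters, exclude):
    # Each getter returns bytes for one named section; excluded names are skipped.
    sections = {name: getter().hex()
                for name, getter in getters.items() if name not in exclude}
    return json.dumps(sections).encode('utf8')


def from_bytes(bytes_data, setters, exclude):
    sections = json.loads(bytes_data.decode('utf8'))
    for name, setter in setters.items():
        if name not in exclude and name in sections:
            setter(bytes.fromhex(sections[name]))


state = {}
data = to_bytes({'model': lambda: b'\x01\x02',
                 'vocab': lambda: b'\x03'}, exclude=())
from_bytes(data, {'vocab': lambda b: state.update(vocab=b),
                  'model': lambda b: state.update(model=b)}, exclude=())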
@@ -336,9 +348,14 @@ class NeuralLabeller(NeuralTagger):
             if dep not in self.labels:
                 self.labels[dep] = len(self.labels)
         token_vector_width = pipeline[0].model.nO
-        self.model = with_flatten(
-            chain(Maxout(token_vector_width, token_vector_width),
-                  Softmax(len(self.labels), token_vector_width)))
+        if self.model is True:
+            self.model = self.Model(len(self.labels), token_vector_width)
+
+    @classmethod
+    def Model(cls, n_tags, token_vector_width):
+        return with_flatten(
+            chain(Maxout(token_vector_width, token_vector_width),
+                  Softmax(n_tags, token_vector_width)))
 
     def get_loss(self, docs, golds, scores):
         scores = self.model.ops.flatten(scores)
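NeuralLabeller gets the same `Model` classmethod treatment, except the output size comes from `len(self.labels)`. The context lines show the idiom that builds that label set: each unseen dependency label is assigned the next integer id, so `len(self.labels)` is always the number of distinct labels seen so far. For example:

labels = {}
for dep in ['nsubj', 'dobj', 'nsubj', 'prep']:
    if dep not in labels:
        labels[dep] = len(labels)

print(labels)   # {'nsubj': 0, 'dobj': 1, 'prep': 2}
# The output layer is then sized with n_tags == len(labels) == 3.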
@@ -412,7 +429,6 @@ cdef class NeuralEntityRecognizer(NeuralParser):
         return (NeuralEntityRecognizer, (self.vocab, self.moves, self.model), None, None)
 
 
-
 cdef class BeamDependencyParser(BeamParser):
     TransitionSystem = ArcEager
 