mirror of https://github.com/explosion/spaCy.git
Fix tagger training
This commit is contained in:
parent
a2357cce3f
commit
386c1a5bd8
|
@ -343,6 +343,7 @@ class NeuralTagger(BaseThincComponent):
|
||||||
|
|
||||||
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
|
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
|
||||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||||
|
bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||||
|
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
|
@ -386,15 +387,13 @@ class NeuralTagger(BaseThincComponent):
|
||||||
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
||||||
vocab.morphology.lemmatizer,
|
vocab.morphology.lemmatizer,
|
||||||
exc=vocab.morphology.exc)
|
exc=vocab.morphology.exc)
|
||||||
token_vector_width = pipeline[0].model.nO
|
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width,
|
self.model = self.Model(self.vocab.morphology.n_tags,
|
||||||
pretrained_dims=self.vocab.vectors_length)
|
pretrained_dims=self.vocab.vectors_length)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, n_tags, token_vector_width, pretrained_dims=0, **cfg):
|
def Model(cls, n_tags, **cfg):
|
||||||
return build_tagger_model(n_tags, token_vector_width,
|
return build_tagger_model(n_tags, **cfg)
|
||||||
pretrained_dims, **cfg)
|
|
||||||
|
|
||||||
def use_params(self, params):
|
def use_params(self, params):
|
||||||
with self.model.use_params(params):
|
with self.model.use_params(params):
|
||||||
|
|
Loading…
Reference in New Issue