mirror of https://github.com/explosion/spaCy.git
Predict tags with encoder
This commit is contained in:
parent
56073a11ef
commit
94e86ae00a
|
@ -5,6 +5,7 @@ from thinc.api import chain, layerize, with_getitem
|
||||||
from thinc.neural import Model, Softmax
|
from thinc.neural import Model, Softmax
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
from .tokens.doc cimport Doc
|
||||||
from .syntax.parser cimport Parser
|
from .syntax.parser cimport Parser
|
||||||
#from .syntax.beam_parser cimport BeamParser
|
#from .syntax.beam_parser cimport BeamParser
|
||||||
from .syntax.ner cimport BiluoPushDown
|
from .syntax.ner cimport BiluoPushDown
|
||||||
|
@ -30,24 +31,42 @@ class TokenVectorEncoder(object):
|
||||||
|
|
||||||
def __call__(self, doc):
    """Encode a single doc in place, then tag it.

    The model is run on a one-element batch containing ``doc``; the first
    item of its output is stored on ``doc.tensor``. ``predict_tags`` is then
    called on a one-element batch containing the same doc.
    """
    encoded = self.model([doc])
    doc.tensor = encoded[0]
    self.predict_tags([doc])
|
||||||
|
|
||||||
def begin_update(self, docs, drop=0.):
    """Run the model's forward pass over a batch and tag the docs.

    Each doc receives the output at its own position in the batch as
    ``doc.tensor``; ``predict_tags`` is then called on the whole batch.

    Returns the pair produced by ``self.model.begin_update`` unchanged:
    the outputs and the callback that accompanies them.
    """
    outputs, backprop = self.model.begin_update(docs, drop=drop)
    # Indexing by position (instead of zipping) keeps the failure loud if
    # the model returns fewer outputs than there are docs.
    for position, doc in enumerate(docs):
        doc.tensor = outputs[position]
    self.predict_tags(docs)
    return outputs, backprop
|
||||||
|
|
||||||
|
def predict_tags(self, docs, drop=0.):
    """Assign the highest-scoring tag id to every token of every doc.

    ``self.tagger.begin_update`` is called on the whole batch; its second
    return value is discarded. The scores are sliced as one row per token,
    with the rows of consecutive docs laid end to end, and the argmax over
    axis 1 of each row is written to the matching token through
    ``doc.vocab.morphology.assign_tag_id``.
    """
    cdef Doc doc
    scores, _ = self.tagger.begin_update(docs, drop=drop)
    offset = 0
    for doc in docs:
        best = scores[offset:offset + len(doc)].argmax(axis=1)
        for token_i, tag_id in enumerate(best):
            doc.vocab.morphology.assign_tag_id(&doc.c[token_i], tag_id)
        # One row consumed per tagged token, so the next doc's slice starts
        # right after this doc's rows.
        offset += len(best)
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None):
    """Run one tagger training step against gold-standard tags.

    The tagger's scores for the batch are copied, 1.0 is subtracted from the
    cell at (token row, gold tag column) for every gold tag, and the result
    is handed to the callback returned by ``self.tagger.begin_update``
    together with ``sgd``.

    Rows are consumed in order: the tags of ``golds[0]`` map to the first
    rows, the tags of ``golds[1]`` to the rows after them, and so on. Tag
    columns are looked up in ``docs[0].vocab.morphology.tag_names``; a tag
    that is not listed there raises ``ValueError`` from ``list.index``.
    """
    scores, finish_update = self.tagger.begin_update(docs, drop=drop)
    losses = scores.copy()
    xp = self.tagger.ops.xp
    # Loop-invariant: the array backend does not change between golds.
    use_scatter_add = hasattr(xp, 'scatter_add')
    idx = 0
    for gold in golds:
        gold_tags = list(gold.tags)
        if not gold_tags:
            continue
        # Resolve the attribute chain once per gold instead of once per tag.
        tag_names = docs[0].vocab.morphology.tag_names
        ids = [tag_names.index(tag) for tag in gold_tags]
        start = idx
        idx += len(ids)
        if use_scatter_add:
            # Index with explicit (row, column) pairs. A bare 1-D id array
            # would select whole rows of the slice by tag id rather than one
            # cell per token. Index arrays are built with the backend's own
            # array module so they live where ``losses`` lives.
            rows = xp.arange(len(ids))
            cols = xp.asarray(ids, dtype='i')
            xp.scatter_add(losses[start:idx], (rows, cols), -1.0)
        else:
            for row, tag_id in enumerate(ids, start):
                losses[row, tag_id] -= 1.
    finish_update(losses, sgd)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue