Fix prediction for tok2vec

This commit is contained in:
Matthew Honnibal 2017-05-21 17:52:01 -05:00
parent f13d6c7359
commit 9b1b0742fd
1 changed files with 4 additions and 8 deletions

View File

@ -93,6 +93,7 @@ class TokenVectorEncoder(object):
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
"""
for docs in cytoolz.partition_all(batch_size, stream):
docs = list(docs)
tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses)
yield from docs
@ -108,19 +109,14 @@ class TokenVectorEncoder(object):
return tokvecs
def set_annotations(self, docs, tokvecses):
for doc, tokvecs in zip(docs, tokvecses):
doc.tensor = tokvecs
def set_annotations(self, docs, tokvecs):
"""Set the tensor attribute for a batch of documents.
docs (iterable): A sequence of `Doc` objects.
tokvecs (object): Vector representation for each token in the documents.
"""
start = 0
for doc in docs:
doc.tensor = tokvecs[start : start + len(doc)]
start += len(doc)
for doc, tokvecs in zip(docs, tokvecses):
assert tokvecs.shape[0] == len(doc)
doc.tensor = tokvecs
def update(self, docs, golds, state=None, drop=0., sgd=None):
"""Update the model.