Fix prediction for tok2vec

This commit is contained in:
Matthew Honnibal 2017-05-21 17:52:01 -05:00
parent f13d6c7359
commit 9b1b0742fd
1 changed file with 4 additions and 8 deletions

View File

@@ -93,6 +93,7 @@ class TokenVectorEncoder(object):
YIELDS (iterator): A sequence of `Doc` objects, in order of input. YIELDS (iterator): A sequence of `Doc` objects, in order of input.
""" """
for docs in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):
docs = list(docs)
tokvecses = self.predict(docs) tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses) self.set_annotations(docs, tokvecses)
yield from docs yield from docs
@@ -108,19 +109,14 @@ class TokenVectorEncoder(object):
return tokvecs return tokvecs
def set_annotations(self, docs, tokvecses): def set_annotations(self, docs, tokvecses):
for doc, tokvecs in zip(docs, tokvecses):
doc.tensor = tokvecs
def set_annotations(self, docs, tokvecs):
"""Set the tensor attribute for a batch of documents. """Set the tensor attribute for a batch of documents.
docs (iterable): A sequence of `Doc` objects. docs (iterable): A sequence of `Doc` objects.
tokvecs (object): Vector representation for each token in the documents. tokvecs (object): Vector representation for each token in the documents.
""" """
start = 0 for doc, tokvecs in zip(docs, tokvecses):
for doc in docs: assert tokvecs.shape[0] == len(doc)
doc.tensor = tokvecs[start : start + len(doc)] doc.tensor = tokvecs
start += len(doc)
def update(self, docs, golds, state=None, drop=0., sgd=None): def update(self, docs, golds, state=None, drop=0., sgd=None):
"""Update the model. """Update the model.