mirror of https://github.com/explosion/spaCy.git
Fix prediction for tok2vec
This commit is contained in:
parent
f13d6c7359
commit
9b1b0742fd
|
@ -93,6 +93,7 @@ class TokenVectorEncoder(object):
|
|||
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
|
||||
"""
|
||||
for docs in cytoolz.partition_all(batch_size, stream):
|
||||
docs = list(docs)
|
||||
tokvecses = self.predict(docs)
|
||||
self.set_annotations(docs, tokvecses)
|
||||
yield from docs
|
||||
|
@ -108,19 +109,14 @@ class TokenVectorEncoder(object):
|
|||
return tokvecs
|
||||
|
||||
def set_annotations(self, docs, tokvecses):
|
||||
for doc, tokvecs in zip(docs, tokvecses):
|
||||
doc.tensor = tokvecs
|
||||
|
||||
def set_annotations(self, docs, tokvecs):
|
||||
"""Set the tensor attribute for a batch of documents.
|
||||
|
||||
docs (iterable): A sequence of `Doc` objects.
|
||||
tokvecs (object): Vector representation for each token in the documents.
|
||||
"""
|
||||
start = 0
|
||||
for doc in docs:
|
||||
doc.tensor = tokvecs[start : start + len(doc)]
|
||||
start += len(doc)
|
||||
for doc, tokvecs in zip(docs, tokvecses):
|
||||
assert tokvecs.shape[0] == len(doc)
|
||||
doc.tensor = tokvecs
|
||||
|
||||
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
||||
"""Update the model.
|
||||
|
|
Loading…
Reference in New Issue