mirror of https://github.com/explosion/spaCy.git
Fix prediction for tok2vec
This commit is contained in:
parent
f13d6c7359
commit
9b1b0742fd
|
@@ -93,6 +93,7 @@ class TokenVectorEncoder(object):
|
||||||
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
|
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
|
||||||
"""
|
"""
|
||||||
for docs in cytoolz.partition_all(batch_size, stream):
|
for docs in cytoolz.partition_all(batch_size, stream):
|
||||||
|
docs = list(docs)
|
||||||
tokvecses = self.predict(docs)
|
tokvecses = self.predict(docs)
|
||||||
self.set_annotations(docs, tokvecses)
|
self.set_annotations(docs, tokvecses)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
@@ -108,19 +109,14 @@ class TokenVectorEncoder(object):
|
||||||
return tokvecs
|
return tokvecs
|
||||||
|
|
||||||
def set_annotations(self, docs, tokvecses):
|
def set_annotations(self, docs, tokvecses):
|
||||||
for doc, tokvecs in zip(docs, tokvecses):
|
|
||||||
doc.tensor = tokvecs
|
|
||||||
|
|
||||||
def set_annotations(self, docs, tokvecs):
|
|
||||||
"""Set the tensor attribute for a batch of documents.
|
"""Set the tensor attribute for a batch of documents.
|
||||||
|
|
||||||
docs (iterable): A sequence of `Doc` objects.
|
docs (iterable): A sequence of `Doc` objects.
|
||||||
tokvecs (object): Vector representation for each token in the documents.
|
tokvecs (object): Vector representation for each token in the documents.
|
||||||
"""
|
"""
|
||||||
start = 0
|
for doc, tokvecs in zip(docs, tokvecses):
|
||||||
for doc in docs:
|
assert tokvecs.shape[0] == len(doc)
|
||||||
doc.tensor = tokvecs[start : start + len(doc)]
|
doc.tensor = tokvecs
|
||||||
start += len(doc)
|
|
||||||
|
|
||||||
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
||||||
"""Update the model.
|
"""Update the model.
|
||||||
|
|
Loading…
Reference in New Issue