diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index b0b440727..91217b80b 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -93,6 +93,7 @@ class TokenVectorEncoder(object): YIELDS (iterator): A sequence of `Doc` objects, in order of input. """ for docs in cytoolz.partition_all(batch_size, stream): + docs = list(docs) tokvecses = self.predict(docs) self.set_annotations(docs, tokvecses) yield from docs @@ -108,19 +109,14 @@ class TokenVectorEncoder(object): return tokvecs def set_annotations(self, docs, tokvecses): - for doc, tokvecs in zip(docs, tokvecses): - doc.tensor = tokvecs - - def set_annotations(self, docs, tokvecs): """Set the tensor attribute for a batch of documents. docs (iterable): A sequence of `Doc` objects. tokvecs (object): Vector representation for each token in the documents. """ - start = 0 - for doc in docs: - doc.tensor = tokvecs[start : start + len(doc)] - start += len(doc) + for doc, tokvecs in zip(docs, tokvecses): + assert tokvecs.shape[0] == len(doc) + doc.tensor = tokvecs def update(self, docs, golds, state=None, drop=0., sgd=None): """Update the model.