mirror of https://github.com/explosion/spaCy.git
Cache features in doc2feats
This commit is contained in:
parent
39ea38c4b1
commit
711ad5edc4
11
spacy/_ml.py
11
spacy/_ml.py
|
@ -10,7 +10,9 @@ from thinc.neural._classes.resnet import Residual
|
|||
from thinc import describe
|
||||
from thinc.describe import Dimension, Synapses, Biases, Gradient
|
||||
from thinc.neural._classes.affine import _set_dimensions_if_needed
|
||||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from .tokens.doc import Doc
|
||||
|
||||
import numpy
|
||||
|
||||
|
@ -167,8 +169,13 @@ def zero_init(model):
|
|||
def doc2feats(cols=None):
|
||||
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
|
||||
def forward(docs, drop=0.):
|
||||
feats = [doc.to_array(cols) for doc in docs]
|
||||
feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
|
||||
feats = []
|
||||
for doc in docs:
|
||||
if 'cached_feats' not in doc.user_data:
|
||||
doc.user_data['cached_feats'] = model.ops.asarray(
|
||||
doc.to_array(cols),
|
||||
dtype='uint64')
|
||||
feats.append(doc.user_data['cached_feats'])
|
||||
return feats, None
|
||||
model = layerize(forward)
|
||||
model.cols = cols
|
||||
|
|
Loading…
Reference in New Issue