mirror of https://github.com/explosion/spaCy.git
Improve efficiency of Doc.to_array
This commit is contained in:
parent
2acc907d55
commit
e10e9ad2c5
|
@ -1,6 +1,7 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
# cython: bounds_check=False
|
# cython: bounds_check=False
|
||||||
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
cimport cython
|
cimport cython
|
||||||
|
@ -567,7 +568,6 @@ cdef class Doc:
|
||||||
"""
|
"""
|
||||||
cdef int i, j
|
cdef int i, j
|
||||||
cdef attr_id_t feature
|
cdef attr_id_t feature
|
||||||
cdef np.ndarray[attr_t, ndim=1] attr_ids
|
|
||||||
cdef np.ndarray[attr_t, ndim=2] output
|
cdef np.ndarray[attr_t, ndim=2] output
|
||||||
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
||||||
if not hasattr(py_attr_ids, '__iter__') \
|
if not hasattr(py_attr_ids, '__iter__') \
|
||||||
|
@ -579,12 +579,17 @@ cdef class Doc:
|
||||||
for id_ in py_attr_ids]
|
for id_ in py_attr_ids]
|
||||||
# Make an array from the attributes --- otherwise our inner loop is
|
# Make an array from the attributes --- otherwise our inner loop is
|
||||||
# Python dict iteration.
|
# Python dict iteration.
|
||||||
attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.uint64)
|
cdef np.ndarray attr_ids = numpy.asarray(py_attr_ids, dtype='i')
|
||||||
output = numpy.ndarray(shape=(self.length, len(attr_ids)),
|
output = numpy.ndarray(shape=(self.length, len(attr_ids)),
|
||||||
dtype=numpy.uint64)
|
dtype=numpy.uint64)
|
||||||
|
c_output = <attr_t*>output.data
|
||||||
|
c_attr_ids = <attr_id_t*>attr_ids.data
|
||||||
|
cdef TokenC* token
|
||||||
|
cdef int nr_attr = attr_ids.shape[0]
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
for j, feature in enumerate(attr_ids):
|
token = &self.c[i]
|
||||||
output[i, j] = get_token_attr(&self.c[i], feature)
|
for j in range(nr_attr):
|
||||||
|
c_output[i*nr_attr + j] = get_token_attr(token, c_attr_ids[j])
|
||||||
# Handle 1d case
|
# Handle 1d case
|
||||||
return output if len(attr_ids) >= 2 else output.reshape((self.length,))
|
return output if len(attr_ids) >= 2 else output.reshape((self.length,))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue