mirror of https://github.com/explosion/spaCy.git
Allow single string attributes in doc.to_array()
Previously inputs like doc.to_array('ORTH') didn't work. Closes #3064
This commit is contained in:
parent
174e85439b
commit
76e3e695af
|
@ -594,10 +594,13 @@ cdef class Doc:
|
||||||
cdef attr_id_t feature
|
cdef attr_id_t feature
|
||||||
cdef np.ndarray[attr_t, ndim=2] output
|
cdef np.ndarray[attr_t, ndim=2] output
|
||||||
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
||||||
if not hasattr(py_attr_ids, '__iter__') \
|
# See also #3064
|
||||||
and not isinstance(py_attr_ids, basestring_):
|
if isinstance(py_attr_ids, basestring_):
|
||||||
|
# Handle inputs like doc.to_array('ORTH')
|
||||||
|
py_attr_ids = [py_attr_ids]
|
||||||
|
elif not hasattr(py_attr_ids, '__iter__'):
|
||||||
|
# Handle inputs like doc.to_array(ORTH)
|
||||||
py_attr_ids = [py_attr_ids]
|
py_attr_ids = [py_attr_ids]
|
||||||
|
|
||||||
# Allow strings, e.g. 'lemma' or 'LEMMA'
|
# Allow strings, e.g. 'lemma' or 'LEMMA'
|
||||||
py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, 'upper') else id_)
|
py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, 'upper') else id_)
|
||||||
for id_ in py_attr_ids]
|
for id_ in py_attr_ids]
|
||||||
|
|
Loading…
Reference in New Issue