Fix serializer

This commit is contained in:
Matthew Honnibal 2017-05-09 18:45:18 +02:00
parent b16ae75824
commit 4efb391994
1 changed files with 21 additions and 30 deletions

View File

@ -1,4 +1,6 @@
# coding: utf8
# cython: infer_types=True
# cython: bounds_check=False
from __future__ import unicode_literals
cimport cython
@ -565,7 +567,7 @@ cdef class Doc:
for i in range(self.length):
self.c[i] = parsed[i]
def from_array(self, attrs, array):
def from_array(self, attrs, int[:, :] array):
"""
Write to a `Doc` object, from an `(M, N)` array of attributes.
"""
@ -573,34 +575,23 @@ cdef class Doc:
cdef attr_id_t attr_id
cdef TokenC* tokens = self.c
cdef int length = len(array)
cdef attr_t[:] values
# Get set up for fast loading
cdef Pool mem = Pool()
cdef int n_attrs = len(attrs)
attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
for i, attr_id in enumerate(attrs):
attr_ids[i] = attr_id
# Now load the data
for i in range(self.length):
token = &self.c[i]
for j in range(n_attrs):
Token.set_struct_attr(token, attr_ids[j], array[i, j])
# Auxiliary loading logic
for col, attr_id in enumerate(attrs):
values = array[:, col]
if attr_id == HEAD:
if attr_id == TAG:
for i in range(length):
tokens[i].head = values[i]
if values[i] >= 1:
tokens[i + values[i]].l_kids += 1
elif values[i] < 0:
tokens[i + values[i]].r_kids += 1
elif attr_id == TAG:
for i in range(length):
if values[i] != 0:
self.vocab.morphology.assign_tag(&tokens[i], values[i])
elif attr_id == POS:
for i in range(length):
tokens[i].pos = <univ_pos_t>values[i]
elif attr_id == DEP:
for i in range(length):
tokens[i].dep = values[i]
elif attr_id == ENT_IOB:
for i in range(length):
tokens[i].ent_iob = values[i]
elif attr_id == ENT_TYPE:
for i in range(length):
tokens[i].ent_type = values[i]
else:
raise ValueError("Unknown attribute ID: %d" % attr_id)
if array[i, col] != 0:
self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
set_children_from_heads(self.c, self.length)
self.is_parsed = bool(HEAD in attrs or DEP in attrs)
self.is_tagged = bool(TAG in attrs or POS in attrs)
@ -645,9 +636,9 @@ cdef class Doc:
self.push_back(lex, has_space)
start = end + has_space
self.from_array(attrs[:, 2:],
[TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE])
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE],
attrs[:, 2:])
return self
def merge(self, int start_idx, int end_idx, *args, **attributes):
"""