* Add copy_state function

This commit is contained in:
Matthew Honnibal 2015-06-01 23:06:30 +02:00
parent c7876aa8b6
commit e09a08bd00
2 changed files with 31 additions and 1 deletions

View File

@ -106,7 +106,8 @@ cdef int head_in_buffer(const State *s, const int child, const int* gold) except
cdef int children_in_stack(const State *s, const int head, const int* gold) except -1
cdef int head_in_stack(const State *s, const int child, const int* gold) except -1
# NOTE(review): new_state is declared twice below; the two lines differ only in
# the `const` qualifier on `sent`. This looks like the removed/added pair of a
# diff hunk --- confirm that only the `const TokenC*` version should remain.
cdef State* new_state(Pool mem, TokenC* sent, const int sent_length) except NULL
cdef State* new_state(Pool mem, const TokenC* sent, const int sent_length) except NULL
# Copies the stack, sentence tokens, buffer index and entities of `src` into
# `dest`. Errors are reported through the `except -1` convention.
cdef int copy_state(State* dest, const State* src) except -1
# Declared nogil: callable without holding the GIL.
cdef int count_left_kids(const TokenC* head) nogil

View File

@ -21,9 +21,17 @@ cdef int add_dep(State *s, int head, int child, int label) except -1:
s.sent[head].r_kids |= 1 << (-dist)
s.sent[head].r_edge = child - head
# Walk up the tree, setting right edge
n_iter = 0
start = head
while s.sent[head].head != 0:
head += s.sent[head].head
s.sent[head].r_edge = child - head
n_iter += 1
if n_iter >= s.sent_len:
tree = [(i + s.sent[i].head) for i in range(s.sent_len)]
msg = "Error adding dependency (%d, %d). Could not find root of tree: %s"
msg = msg % (start, child, tree)
raise Exception(msg)
else:
s.sent[head].l_kids |= 1 << dist
s.sent[head].l_edge = (child + s.sent[child].l_edge) - head
@ -155,6 +163,27 @@ cdef State* new_state(Pool mem, const TokenC* sent, const int sent_len) except N
return s
cdef int copy_state(State* dest, const State* src) except -1:
    """Overwrite the parse state held in `dest` with the contents of `src`.

    Copies, in order: the stack, the sentence tokens, the buffer index `i`
    and the entity records. `dest` must already describe a sentence of the
    same length as `src` (`dest.sent_len == src.sent_len`); its existing
    arrays are written in place.

    There is no explicit return value; failures (the length check and the
    sanity checks at the end) raise AssertionError, which is propagated to
    the caller through the `except -1` convention.
    """
    # Typed index so both copy loops compile to plain C loops rather than
    # iterating with a Python object.
    cdef int i
    assert dest.sent_len == src.sent_len
    # Copy stack --- the stack uses pointer arithmetic: entries are read at
    # offsets 0, -1, ..., -(stack_len - 1) from the `stack` pointer.
    # Shift dest's pointer by the difference in depth first; this has to
    # happen before dest.stack_len is overwritten below.
    # NOTE(review): presumably stack[0] is the top of the stack --- confirm
    # against push/pop, which are not visible here.
    dest.stack += (src.stack_len - dest.stack_len)
    for i in range(src.stack_len):
        dest.stack[-i] = src.stack[-i]
    dest.stack_len = src.stack_len
    # Copy sentence (i.e. the parse): all sent_len tokens, not only the ones
    # before the buffer index.
    memcpy(dest.sent, src.sent, sizeof(TokenC) * src.sent_len)
    dest.i = src.i
    # Copy assigned entities --- same pointer-arithmetic layout as the stack.
    dest.ent += (src.ents_len - dest.ents_len)
    for i in range(src.ents_len):
        dest.ent[-i] = src.ent[-i]
    dest.ents_len = src.ents_len
    # Sanity checks on the result of the copy.
    assert dest.sent[dest.i].head == src.sent[src.i].head
    if dest.stack_len > 0:
        assert dest.stack[0] < dest.i
# From https://en.wikipedia.org/wiki/Hamming_weight
cdef inline uint32_t _popcount(uint32_t x) nogil:
"""Find number of non-zero bits."""