From be4a640f0c3b06a76cf0fd0e5ebedacb0bbf7c2a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 30 May 2017 20:37:24 +0200 Subject: [PATCH] Fix arc eager label costs for uint64 --- spacy/gold.pxd | 1 + spacy/gold.pyx | 1 + spacy/syntax/arc_eager.pyx | 37 +++++++++++++++++++------------------ 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/spacy/gold.pxd b/spacy/gold.pxd index c8eadbd31..364e083fb 100644 --- a/spacy/gold.pxd +++ b/spacy/gold.pxd @@ -8,6 +8,7 @@ from .syntax.transition_system cimport Transition cdef struct GoldParseC: int* tags int* heads + int* has_dep attr_t* labels int** brackets Transition* ner diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 4290c13cf..de48501fb 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -385,6 +385,7 @@ cdef class GoldParse: self.c.tags = self.mem.alloc(len(doc), sizeof(int)) self.c.heads = self.mem.alloc(len(doc), sizeof(int)) self.c.labels = self.mem.alloc(len(doc), sizeof(attr_t)) + self.c.has_dep = self.mem.alloc(len(doc), sizeof(int)) self.c.ner = self.mem.alloc(len(doc), sizeof(Transition)) self.words = [None] * len(doc) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 7a9afdd06..7df5fe081 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -60,7 +60,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cost += 1 if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)): cost += 1 - cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0 + cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0 return cost @@ -73,7 +73,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog cost += gold.heads[target] == B_i if gold.heads[B_i] == B_i or gold.heads[B_i] < target: break - if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: + if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: cost += 1 return cost @@ -84,14 +84,14 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c elif stcls.H(child) == gold.heads[child]: return 1 # Head in buffer - elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1: + elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0: return 1 else: return 0 cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: - if gold.labels[child] == -1: + if not gold.has_dep[child]: return True elif gold.heads[child] == head: return True @@ -100,9 +100,9 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil: - if gold.labels[child] == -1: + if not gold.has_dep[child]: return True - elif label == -1: + elif label == 0: return True elif gold.labels[child] == label: return True @@ -111,8 +111,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: - return gold.labels[word] == -1 or gold.heads[word] == word - + return gold.heads[word] == word or not gold.has_dep[word] cdef class Shift: @staticmethod @@ -165,7 +164,7 @@ cdef class Reduce: cost -= 1 if gold.heads[S_i] == st.S(0): cost -= 1 - if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0: + if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: cost -= 1 return cost @@ -285,9 +284,9 @@ cdef class Break: return 0 cdef int _get_root(int word, const GoldParseC* gold) nogil: - while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0: + while gold.heads[word] != word and not gold.has_dep[word] and word >= 0: word = gold.heads[word] - if gold.labels[word] == -1: + if not gold.has_dep[word]: return -1 else: return word @@ -363,9 +362,10 @@ cdef class ArcEager(TransitionSystem): for i in range(gold.length): if gold.heads[i] is None: # Missing values gold.c.heads[i] = i - gold.c.labels[i] = -1 + gold.c.has_dep[i] = False else: label = gold.labels[i] + gold.c.has_dep[i] = True if label.upper() == 'ROOT': label = 'ROOT' gold.c.heads[i] = gold.heads[i] @@ -440,18 +440,19 @@ cdef class ArcEager(TransitionSystem): cdef int set_valid(self, int* output, const StateC* st) nogil: cdef bint[N_MOVES] is_valid - is_valid[SHIFT] = Shift.is_valid(st, -1) - is_valid[REDUCE] = Reduce.is_valid(st, -1) - is_valid[LEFT] = LeftArc.is_valid(st, -1) - is_valid[RIGHT] = RightArc.is_valid(st, -1) - is_valid[BREAK] = Break.is_valid(st, -1) + is_valid[SHIFT] = Shift.is_valid(st, 0) + is_valid[REDUCE] = Reduce.is_valid(st, 0) + is_valid[LEFT] = LeftArc.is_valid(st, 0) + is_valid[RIGHT] = RightArc.is_valid(st, 0) + is_valid[BREAK] = Break.is_valid(st, 0) cdef int i for i in range(self.n_moves): output[i] = is_valid[self.c[i].move] cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, GoldParse gold) except -1: - cdef int i, move, label + cdef int i, move + cdef attr_t label cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef weight_t[N_MOVES] move_costs