diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index b0c3819b6..7ac11fd46 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -20,7 +20,7 @@ from .stateclass cimport StateClass DEF NON_MONOTONIC = True DEF USE_BREAK = False -DEF USE_ROOT_ARC_SEGMENT = True +DEF USE_ROOT_ARC_SEGMENT = False cdef weight_t MIN_SCORE = -90000 @@ -69,6 +69,7 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cost += gold.heads[target] == B_i if gold.heads[B_i] == B_i or gold.heads[B_i] < target: break + cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0 return cost @@ -244,14 +245,22 @@ cdef class Break: @staticmethod cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef int cost = 0 + cdef int S_i, B_i + for i in range(s.stack_depth()): + S_i = s.S(i) + for j in range(s.buffer_length()): + B_i = s.B(j) + cost += gold.heads[S_i] == B_i + cost += gold.heads[B_i] == S_i # Check for sentence boundary --- if it's here, we can't have any deps # between stack and buffer, so rest of action is irrelevant. s0_root = _get_root(s.S(0), gold) b0_root = _get_root(s.B(0), gold) - if s0_root == -1 or b0_root == -1 or s0_root != b0_root: - return 0 + if s0_root != b0_root or s0_root == -1 or b0_root == -1: + return cost else: - return 1 + return cost + 1 @staticmethod cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: