diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 5f6e1f303..5242452b6 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -97,7 +97,7 @@ cdef inline bint at_eol(const State *s) nogil: cdef inline bint is_final(const State *s) nogil: - return at_eol(s) # The stack will be attached to root anyway + return at_eol(s) and s.stack_len < 2 cdef int children_in_buffer(const State *s, const int head, const int* gold) except -1 diff --git a/spacy/syntax/_state.pyx b/spacy/syntax/_state.pyx index b2dbe3772..12295905b 100644 --- a/spacy/syntax/_state.pyx +++ b/spacy/syntax/_state.pyx @@ -35,11 +35,6 @@ cdef int push_stack(State *s) except -1: s.stack[0] = s.i s.stack_len += 1 s.i += 1 - if at_eol(s): - while s.stack_len != 0: - if not has_head(get_s0(s)): - get_s0(s).dep = 0 - pop_stack(s) cdef int children_in_buffer(const State *s, int head, const int* gold) except -1: diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 78aaec60f..fb544aa3e 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -43,7 +43,7 @@ cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, - LEFT: {}, BREAK: {'ROOT': True}} + LEFT: {'ROOT': True}, BREAK: {'ROOT': True}} for raw_text, segmented, (ids, words, tags, heads, labels, iob) in gold_parses: for child, head, label in zip(ids, heads, labels): if label != 'ROOT': @@ -126,7 +126,11 @@ cdef int _do_shift(const Transition* self, State* state) except -1: cdef int _do_left(const Transition* self, State* state) except -1: - add_dep(state, state.i, state.stack[0], self.label) + # Interpret left-arcs from EOL as attachment to root + if at_eol(state): + add_dep(state, state.stack[0], state.stack[0], self.label) + else: + add_dep(state, state.i, state.stack[0], self.label) pop_stack(state) @@ -195,6 +199,13 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce if gold.c_heads[s.stack[0]] == s.i: cost += self.label != gold.c_labels[s.stack[0]] return cost + # If we're at EOL, then the left arc will add an arc to ROOT. + elif at_eol(s): + # Are we root? + cost += gold.c_heads[s.stack[0]] != s.stack[0] + # Are we labelling correctly? + cost += self.label != gold.c_labels[s.stack[0]] + return cost cost += head_in_buffer(s, s.stack[0], gold.c_heads) cost += children_in_buffer(s, s.stack[0], gold.c_heads)