mirror of https://github.com/explosion/spaCy.git
* Ensure high loss for invalid moves, and fix label reading for arc-eager
This commit is contained in:
parent
f5f15a1ef2
commit
fdabd93bfb
|
@ -35,14 +35,15 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_parses):
|
def get_labels(cls, gold_parses):
|
||||||
labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True},
|
labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
|
||||||
LEFT: {0: True}, BREAK: {0: True}}
|
LEFT: {}, BREAK: {'ROOT': True}}
|
||||||
for parse in gold_parses:
|
for parse in gold_parses:
|
||||||
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
|
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
|
||||||
if head > i:
|
if label != 'ROOT':
|
||||||
labels[RIGHT][label] = True
|
if head > i:
|
||||||
else:
|
labels[RIGHT][label] = True
|
||||||
labels[LEFT][label] = True
|
elif head < i:
|
||||||
|
labels[LEFT][label] = True
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||||
|
@ -71,6 +72,8 @@ cdef class ArcEager(TransitionSystem):
|
||||||
if scores[i] > score and is_valid[self.c[i].move]:
|
if scores[i] > score and is_valid[self.c[i].move]:
|
||||||
best = self.c[i]
|
best = self.c[i]
|
||||||
score = scores[i]
|
score = scores[i]
|
||||||
|
assert best.clas < self.n_moves
|
||||||
|
assert score > MIN_SCORE
|
||||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||||
# actions
|
# actions
|
||||||
if best.move == SHIFT:
|
if best.move == SHIFT:
|
||||||
|
@ -85,7 +88,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
cdef int _do_shift(const Transition* self, State* state) except -1:
|
cdef int _do_shift(const Transition* self, State* state) except -1:
|
||||||
# Set the dep label, in case we need it after we reduce
|
# Set the dep label, in case we need it after we reduce
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
get_s0(state).dep = self.label
|
state.sent[state.i].dep = self.label
|
||||||
push_stack(state)
|
push_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,7 +127,8 @@ do_funcs[BREAK] = _do_break
|
||||||
|
|
||||||
|
|
||||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
assert not at_eol(s)
|
if not _can_shift(s):
|
||||||
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||||
|
@ -137,7 +141,8 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
|
||||||
|
|
||||||
|
|
||||||
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
assert s.stack_len >= 1
|
if not _can_right(s):
|
||||||
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.c_heads[s.i] == s.stack[0]:
|
if gold.c_heads[s.i] == s.stack[0]:
|
||||||
cost += self.label != gold.c_labels[s.i]
|
cost += self.label != gold.c_labels[s.i]
|
||||||
|
@ -151,7 +156,8 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
|
||||||
|
|
||||||
|
|
||||||
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
assert s.stack_len >= 1
|
if not _can_left(s):
|
||||||
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.c_heads[s.stack[0]] == s.i:
|
if gold.c_heads[s.stack[0]] == s.i:
|
||||||
cost += self.label != gold.c_labels[s.stack[0]]
|
cost += self.label != gold.c_labels[s.stack[0]]
|
||||||
|
@ -166,6 +172,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
|
||||||
|
|
||||||
|
|
||||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
if not _can_reduce(s):
|
||||||
|
return 9000
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
|
@ -174,6 +182,8 @@ cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) ex
|
||||||
|
|
||||||
|
|
||||||
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
if not _can_break(s):
|
||||||
|
return 9000
|
||||||
# When we break, we Reduce all of the words on the stack.
|
# When we break, we Reduce all of the words on the stack.
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
|
|
Loading…
Reference in New Issue