* Update parser oracle for missing heads

This commit is contained in:
Matthew Honnibal 2015-05-24 20:05:58 +02:00
parent 541c62c126
commit 78487f3e66
1 changed file with 10 additions and 6 deletions

View File

@@ -69,7 +69,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(gold.length): for i in range(gold.length):
if gold.heads[i] is None: # Missing values if gold.heads[i] is None: # Missing values
gold.c_heads[i] = i gold.c_heads[i] = i
gold.c_labels[i] = self.strings[''] gold.c_labels[i] = -1
else: else:
gold.c_heads[i] = gold.heads[i] gold.c_heads[i] = gold.heads[i]
gold.c_labels[i] = self.strings[gold.labels[i]] gold.c_labels[i] = self.strings[gold.labels[i]]
@@ -252,7 +252,9 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
if gold.c_heads[s.i] == s.stack[0]: if gold.c_heads[s.i] == s.stack[0]:
cost += self.label != gold.c_labels[s.i] cost += self.label != gold.c_labels[s.i]
return cost return cost
cost += head_in_buffer(s, s.i, gold.c_heads) # This indicates missing head
if gold.c_labels[s.i] != -1:
cost += head_in_buffer(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads) cost += children_in_stack(s, s.i, gold.c_heads)
cost += head_in_stack(s, s.i, gold.c_heads) cost += head_in_stack(s, s.i, gold.c_heads)
if NON_MONOTONIC: if NON_MONOTONIC:
@@ -270,16 +272,18 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
# If we're at EOL, then the left arc will add an arc to ROOT. # If we're at EOL, then the left arc will add an arc to ROOT.
elif at_eol(s): elif at_eol(s):
# Are we root? # Are we root?
cost += gold.c_heads[s.stack[0]] != s.stack[0] if gold.c_labels[s.stack[0]] != -1:
# Are we labelling correctly? cost += gold.c_heads[s.stack[0]] != s.stack[0]
cost += self.label != gold.c_labels[s.stack[0]] # Are we labelling correctly?
cost += self.label != gold.c_labels[s.stack[0]]
return cost return cost
cost += head_in_buffer(s, s.stack[0], gold.c_heads) cost += head_in_buffer(s, s.stack[0], gold.c_heads)
cost += children_in_buffer(s, s.stack[0], gold.c_heads) cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC and s.stack_len >= 2: if NON_MONOTONIC and s.stack_len >= 2:
cost += gold.c_heads[s.stack[0]] == s.stack[-1] cost += gold.c_heads[s.stack[0]] == s.stack[-1]
cost += gold.c_heads[s.stack[0]] == s.stack[0] if gold.c_labels[s.stack[0]] != -1:
cost += gold.c_heads[s.stack[0]] == s.stack[0]
return cost return cost