Fix arc-eager oracle tests

This commit is contained in:
Matthew Honnibal 2018-05-01 16:16:14 +02:00
parent 31ed64e9b0
commit b43bfd3524
1 changed files with 2 additions and 98 deletions

View File

@ -54,76 +54,12 @@ def gold(doc, words):
else:
raise NotImplementedError
def test_shift_is_gold_at_first_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
assert arc_eager.get_cost(state, gold, 'S') == 0
def test_reduce_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'D') != 0
def test_break_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'B-ROOT') != 0
def test_left_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'L-left') != 0
def test_right_is_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'R-right') == 0
def test_reduce_is_gold_at_third_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
assert arc_eager.get_cost(state, gold, 'D') == 0
def test_cant_arc_is_gold_at_third_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
assert not state.can_arc()
def test_fourth_state_is_final(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
arc_eager.transition(state, 'D')
assert state.is_final()
def test_oracle_sequence_two_words(arc_eager, doc, gold):
parser = DependencyParser(doc.vocab, moves=arc_eager)
state = StateClass(doc)
parser.moves.preprocess_gold(gold)
actions = parser.moves.get_oracle_sequence(doc, gold)
names = [parser.moves.class_name(i) for i in actions]
assert names == ['S', 'R-right', 'D']
@pytest.mark.xfail
def test_oracle_four_words(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'ROOT', 'left', 'ROOT']
actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S']
actions = ['L-left', 'B-ROOT', 'L-left']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final()
for i, state_costs in enumerate(cost_history):
@ -133,38 +69,6 @@ def test_oracle_four_words(arc_eager, vocab):
if other_action != actions[i]:
assert cost >= 1
def test_non_monotonic_sequence_four_words(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final()
c0 = cost_history.pop(0)
assert c0['S'] == 0.0
c1 = cost_history.pop(0)
assert c1['L-left'] == 0.0
assert c1['R-right'] != 0.0
c2 = cost_history.pop(0)
assert c2['R-right'] != 0.0
assert c2['B-ROOT'] == 0.0
assert c2['D'] == 0.0
c3 = cost_history.pop(0)
assert c3['L-left'] == -1.0
def test_reduce_is_gold_at_break(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final(), state.print_state(words)
c0 = cost_history.pop(0)
c1 = cost_history.pop(0)
c2 = cost_history.pop(0)
c3 = cost_history.pop(0)
assert c3['D'] == 0.0
annot_tuples = [
(0, 'When', 'WRB', 11, 'advmod', 'O'),