2018-07-24 21:38:44 +00:00
|
|
|
# coding: utf8
|
2018-03-27 17:23:02 +00:00
|
|
|
from __future__ import unicode_literals
|
2018-04-01 08:41:52 +00:00
|
|
|
|
2018-07-24 21:38:44 +00:00
|
|
|
import pytest
|
|
|
|
from spacy.vocab import Vocab
|
|
|
|
from spacy.pipeline import DependencyParser
|
|
|
|
from spacy.tokens import Doc
|
|
|
|
from spacy.gold import GoldParse
|
|
|
|
from spacy.syntax.nonproj import projectivize
|
|
|
|
from spacy.syntax.stateclass import StateClass
|
|
|
|
from spacy.syntax.arc_eager import ArcEager
|
2018-04-01 08:41:52 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_sequence_costs(M, words, heads, deps, transitions):
|
|
|
|
doc = Doc(Vocab(), words=words)
|
|
|
|
gold = GoldParse(doc, heads=heads, deps=deps)
|
|
|
|
state = StateClass(doc)
|
|
|
|
M.preprocess_gold(gold)
|
|
|
|
cost_history = []
|
|
|
|
for gold_action in transitions:
|
|
|
|
state_costs = {}
|
|
|
|
for i in range(M.n_moves):
|
|
|
|
name = M.class_name(i)
|
|
|
|
state_costs[name] = M.get_cost(state, gold, i)
|
|
|
|
M.transition(state, gold_action)
|
|
|
|
cost_history.append(state_costs)
|
|
|
|
return state, cost_history
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
def vocab():
|
|
|
|
return Vocab()
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
def arc_eager(vocab):
|
|
|
|
moves = ArcEager(vocab.strings, ArcEager.get_actions())
|
|
|
|
moves.add_action(2, 'left')
|
|
|
|
moves.add_action(3, 'right')
|
|
|
|
return moves
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
def words():
|
|
|
|
return ['a', 'b']
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
def doc(words, vocab):
|
|
|
|
if vocab is None:
|
|
|
|
vocab = Vocab()
|
|
|
|
return Doc(vocab, words=list(words))
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
def gold(doc, words):
|
|
|
|
if len(words) == 2:
|
|
|
|
return GoldParse(doc, words=['a', 'b'], heads=[0, 0], deps=['ROOT', 'right'])
|
|
|
|
else:
|
|
|
|
raise NotImplementedError
|
|
|
|
|
2018-05-01 14:16:14 +00:00
|
|
|
@pytest.mark.xfail
|
2018-04-01 08:41:52 +00:00
|
|
|
def test_oracle_four_words(arc_eager, vocab):
|
|
|
|
words = ['a', 'b', 'c', 'd']
|
|
|
|
heads = [1, 1, 3, 3]
|
|
|
|
deps = ['left', 'ROOT', 'left', 'ROOT']
|
2018-05-01 14:16:14 +00:00
|
|
|
actions = ['L-left', 'B-ROOT', 'L-left']
|
2018-04-01 08:41:52 +00:00
|
|
|
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
|
|
|
assert state.is_final()
|
|
|
|
for i, state_costs in enumerate(cost_history):
|
|
|
|
# Check gold moves is 0 cost
|
|
|
|
assert state_costs[actions[i]] == 0.0, actions[i]
|
|
|
|
for other_action, cost in state_costs.items():
|
|
|
|
if other_action != actions[i]:
|
|
|
|
assert cost >= 1
|
|
|
|
|
2018-03-27 17:23:02 +00:00
|
|
|
|
|
|
|
annot_tuples = [
|
|
|
|
(0, 'When', 'WRB', 11, 'advmod', 'O'),
|
|
|
|
(1, 'Walter', 'NNP', 2, 'compound', 'B-PERSON'),
|
|
|
|
(2, 'Rodgers', 'NNP', 11, 'nsubj', 'L-PERSON'),
|
|
|
|
(3, ',', ',', 2, 'punct', 'O'),
|
|
|
|
(4, 'our', 'PRP$', 6, 'poss', 'O'),
|
|
|
|
(5, 'embedded', 'VBN', 6, 'amod', 'O'),
|
|
|
|
(6, 'reporter', 'NN', 2, 'appos', 'O'),
|
|
|
|
(7, 'with', 'IN', 6, 'prep', 'O'),
|
|
|
|
(8, 'the', 'DT', 10, 'det', 'B-ORG'),
|
|
|
|
(9, '3rd', 'NNP', 10, 'compound', 'I-ORG'),
|
|
|
|
(10, 'Cavalry', 'NNP', 7, 'pobj', 'L-ORG'),
|
|
|
|
(11, 'says', 'VBZ', 44, 'advcl', 'O'),
|
|
|
|
(12, 'three', 'CD', 13, 'nummod', 'U-CARDINAL'),
|
|
|
|
(13, 'battalions', 'NNS', 16, 'nsubj', 'O'),
|
|
|
|
(14, 'of', 'IN', 13, 'prep', 'O'),
|
|
|
|
(15, 'troops', 'NNS', 14, 'pobj', 'O'),
|
|
|
|
(16, 'are', 'VBP', 11, 'ccomp', 'O'),
|
|
|
|
(17, 'on', 'IN', 16, 'prep', 'O'),
|
|
|
|
(18, 'the', 'DT', 19, 'det', 'O'),
|
|
|
|
(19, 'ground', 'NN', 17, 'pobj', 'O'),
|
|
|
|
(20, ',', ',', 17, 'punct', 'O'),
|
|
|
|
(21, 'inside', 'IN', 17, 'prep', 'O'),
|
|
|
|
(22, 'Baghdad', 'NNP', 21, 'pobj', 'U-GPE'),
|
|
|
|
(23, 'itself', 'PRP', 22, 'appos', 'O'),
|
|
|
|
(24, ',', ',', 16, 'punct', 'O'),
|
|
|
|
(25, 'have', 'VBP', 26, 'aux', 'O'),
|
|
|
|
(26, 'taken', 'VBN', 16, 'dep', 'O'),
|
|
|
|
(27, 'up', 'RP', 26, 'prt', 'O'),
|
|
|
|
(28, 'positions', 'NNS', 26, 'dobj', 'O'),
|
|
|
|
(29, 'they', 'PRP', 31, 'nsubj', 'O'),
|
|
|
|
(30, "'re", 'VBP', 31, 'aux', 'O'),
|
|
|
|
(31, 'going', 'VBG', 26, 'parataxis', 'O'),
|
|
|
|
(32, 'to', 'TO', 33, 'aux', 'O'),
|
|
|
|
(33, 'spend', 'VB', 31, 'xcomp', 'O'),
|
2018-07-24 21:38:44 +00:00
|
|
|
(34, 'the', 'DT', 35, 'det', 'B-TIME'),
|
2018-03-27 17:23:02 +00:00
|
|
|
(35, 'night', 'NN', 33, 'dobj', 'L-TIME'),
|
|
|
|
(36, 'there', 'RB', 33, 'advmod', 'O'),
|
|
|
|
(37, 'presumably', 'RB', 33, 'advmod', 'O'),
|
|
|
|
(38, ',', ',', 44, 'punct', 'O'),
|
|
|
|
(39, 'how', 'WRB', 40, 'advmod', 'O'),
|
|
|
|
(40, 'many', 'JJ', 41, 'amod', 'O'),
|
|
|
|
(41, 'soldiers', 'NNS', 44, 'pobj', 'O'),
|
|
|
|
(42, 'are', 'VBP', 44, 'aux', 'O'),
|
|
|
|
(43, 'we', 'PRP', 44, 'nsubj', 'O'),
|
|
|
|
(44, 'talking', 'VBG', 44, 'ROOT', 'O'),
|
|
|
|
(45, 'about', 'IN', 44, 'prep', 'O'),
|
|
|
|
(46, 'right', 'RB', 47, 'advmod', 'O'),
|
|
|
|
(47, 'now', 'RB', 44, 'advmod', 'O'),
|
|
|
|
(48, '?', '.', 44, 'punct', 'O')]
|
|
|
|
|
|
|
|
def test_get_oracle_actions():
|
|
|
|
doc = Doc(Vocab(), words=[t[1] for t in annot_tuples])
|
|
|
|
parser = DependencyParser(doc.vocab)
|
|
|
|
parser.moves.add_action(0, '')
|
|
|
|
parser.moves.add_action(1, '')
|
|
|
|
parser.moves.add_action(1, '')
|
|
|
|
parser.moves.add_action(4, 'ROOT')
|
|
|
|
for i, (id_, word, tag, head, dep, ent) in enumerate(annot_tuples):
|
|
|
|
if head > i:
|
|
|
|
parser.moves.add_action(2, dep)
|
|
|
|
elif head < i:
|
|
|
|
parser.moves.add_action(3, dep)
|
|
|
|
ids, words, tags, heads, deps, ents = zip(*annot_tuples)
|
|
|
|
heads, deps = projectivize(heads, deps)
|
|
|
|
gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
|
|
|
|
parser.moves.preprocess_gold(gold)
|
|
|
|
actions = parser.moves.get_oracle_sequence(doc, gold)
|