spaCy/spacy/tests/parser/test_arc_eager_oracle.py

156 lines
5.0 KiB
Python
Raw Normal View History

💫 Refactor test suite (#2568) ## Description Related issues: #2379 (should be fixed by separating model tests) * **total execution time down from > 300 seconds to under 60 seconds** 🎉 * removed all model-specific tests that could only really be run manually anyway – those will now live in a separate test suite in the [`spacy-models`](https://github.com/explosion/spacy-models) repository and are already integrated into our new model training infrastructure * changed all relative imports to absolute imports to prepare for moving the test suite from `/spacy/tests` to `/tests` (it'll now always test against the installed version) * merged old regression tests into collections, e.g. `test_issue1001-1500.py` (about 90% of the regression tests are very short anyways) * tidied up and rewrote existing tests wherever possible ### Todo - [ ] move tests to `/tests` and adjust CI commands accordingly - [x] move model test suite from internal repo to `spacy-models` - [x] ~~investigate why `pipeline/test_textcat.py` is flakey~~ - [x] review old regression tests (leftover files) and see if they can be merged, simplified or deleted - [ ] update documentation on how to run tests ### Types of change enhancement, tests ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information.
2018-07-24 21:38:44 +00:00
# coding: utf8
from __future__ import unicode_literals
2018-04-01 08:41:52 +00:00
💫 Refactor test suite (#2568) ## Description Related issues: #2379 (should be fixed by separating model tests) * **total execution time down from > 300 seconds to under 60 seconds** 🎉 * removed all model-specific tests that could only really be run manually anyway – those will now live in a separate test suite in the [`spacy-models`](https://github.com/explosion/spacy-models) repository and are already integrated into our new model training infrastructure * changed all relative imports to absolute imports to prepare for moving the test suite from `/spacy/tests` to `/tests` (it'll now always test against the installed version) * merged old regression tests into collections, e.g. `test_issue1001-1500.py` (about 90% of the regression tests are very short anyways) * tidied up and rewrote existing tests wherever possible ### Todo - [ ] move tests to `/tests` and adjust CI commands accordingly - [x] move model test suite from internal repo to `spacy-models` - [x] ~~investigate why `pipeline/test_textcat.py` is flakey~~ - [x] review old regression tests (leftover files) and see if they can be merged, simplified or deleted - [ ] update documentation on how to run tests ### Types of change enhancement, tests ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information.
2018-07-24 21:38:44 +00:00
import pytest
from spacy.vocab import Vocab
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.gold import GoldParse
from spacy.syntax.nonproj import projectivize
from spacy.syntax.stateclass import StateClass
from spacy.syntax.arc_eager import ArcEager
2018-04-01 08:41:52 +00:00
def get_sequence_costs(M, words, heads, deps, transitions):
    """Replay ``transitions`` on a fresh parse state and log every move's cost.

    A ``Doc`` is built from ``words`` with a new ``Vocab``, a ``GoldParse``
    is created from ``heads`` and ``deps``, and the gold parse is run through
    ``M.preprocess_gold``. Before each action in ``transitions`` is applied,
    the cost of each of the ``M.n_moves`` moves is recorded.

    M: the transition system; must provide ``preprocess_gold``, ``n_moves``,
        ``class_name``, ``get_cost`` and ``transition``.
    words: token texts for the document.
    heads: gold heads passed to ``GoldParse``.
    deps: gold dependency labels passed to ``GoldParse``.
    transitions: actions to apply, in order, via ``M.transition``.
    RETURNS: ``(state, cost_history)``; ``cost_history`` holds one dict per
        applied action, mapping move class name to its cost at that point.
    """
    doc = Doc(Vocab(), words=words)
    gold = GoldParse(doc, heads=heads, deps=deps)
    state = StateClass(doc)
    M.preprocess_gold(gold)

    history = []
    for action in transitions:
        # Snapshot the costs while the state is still untouched by `action`.
        snapshot = {}
        for move_id in range(M.n_moves):
            move_name = M.class_name(move_id)
            snapshot[move_name] = M.get_cost(state, gold, move_id)
        history.append(snapshot)
        M.transition(state, action)
    return state, history
@pytest.fixture
def vocab():
    """Provide a newly constructed ``Vocab`` (default arguments)."""
    fresh_vocab = Vocab()
    return fresh_vocab
2018-04-01 08:41:52 +00:00
@pytest.fixture
def arc_eager(vocab):
    """Provide an ``ArcEager`` transition system with two extra labelled moves.

    The system starts from ``ArcEager.get_actions()`` and then registers
    move 2 with label "left" and move 3 with label "right".
    NOTE(review): 2 and 3 are presumably the left-arc / right-arc move ids
    -- confirm against ``ArcEager``.
    """
    system = ArcEager(vocab.strings, ArcEager.get_actions())
    for move_id, label in ((2, "left"), (3, "right")):
        system.add_action(move_id, label)
    return system
2018-04-01 08:41:52 +00:00
@pytest.fixture
def words():
    """Provide the token texts of the two-word example sentence."""
    tokens = ["a", "b"]
    return tokens
2018-04-01 08:41:52 +00:00
@pytest.fixture
def doc(words, vocab):
    """Provide a ``Doc`` built from a copy of ``words``.

    A new ``Vocab`` is created when ``vocab`` is ``None``; otherwise the
    given ``vocab`` is used.
    """
    doc_vocab = Vocab() if vocab is None else vocab
    return Doc(doc_vocab, words=list(words))
2018-04-01 08:41:52 +00:00
@pytest.fixture
def gold(doc, words):
    """Provide the gold parse for the two-word example document.

    Only two-word inputs are supported: both tokens are headed by token 0,
    with labels "ROOT" and "right".

    RAISES: NotImplementedError if ``words`` does not have exactly two items.
    """
    if len(words) != 2:
        raise NotImplementedError
    return GoldParse(doc, words=["a", "b"], heads=[0, 0], deps=["ROOT", "right"])
2018-05-01 14:16:14 +00:00
@pytest.mark.xfail
def test_oracle_four_words(arc_eager, vocab):
    """Check oracle costs along a gold action sequence for a four-word sentence.

    The gold actions are replayed through ``get_sequence_costs``. The final
    state must be terminal, each gold action must have cost 0.0 at the point
    it is taken, and every other move must cost at least 1 at that point.
    The test is currently marked as an expected failure.
    """
    words = ["a", "b", "c", "d"]
    heads = [1, 1, 3, 3]
    deps = ["left", "ROOT", "left", "ROOT"]
    actions = ["L-left", "B-ROOT", "L-left"]

    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
    assert state.is_final()

    # One cost snapshot was recorded per gold action, so pair them directly.
    for gold_action, state_costs in zip(actions, cost_history):
        # The gold move itself must be free...
        assert state_costs[gold_action] == 0.0, gold_action
        # ...while every competing move must be penalised.
        for move_name, cost in state_costs.items():
            if move_name == gold_action:
                continue
            assert cost >= 1
# Annotated example sentence used as input by test_get_oracle_actions.
# Each row is (id, word, tag, head, dep, ent); ids run 0..48 in row order.
# ``head`` is a token index into this same table: the ROOT row (44) names
# itself as its head.
# NOTE(review): the ent column looks like BILUO-style entity tags -- confirm.
annot_tuples = [
(0, "When", "WRB", 11, "advmod", "O"),
(1, "Walter", "NNP", 2, "compound", "B-PERSON"),
(2, "Rodgers", "NNP", 11, "nsubj", "L-PERSON"),
(3, ",", ",", 2, "punct", "O"),
(4, "our", "PRP$", 6, "poss", "O"),
(5, "embedded", "VBN", 6, "amod", "O"),
(6, "reporter", "NN", 2, "appos", "O"),
(7, "with", "IN", 6, "prep", "O"),
(8, "the", "DT", 10, "det", "B-ORG"),
(9, "3rd", "NNP", 10, "compound", "I-ORG"),
(10, "Cavalry", "NNP", 7, "pobj", "L-ORG"),
(11, "says", "VBZ", 44, "advcl", "O"),
(12, "three", "CD", 13, "nummod", "U-CARDINAL"),
(13, "battalions", "NNS", 16, "nsubj", "O"),
(14, "of", "IN", 13, "prep", "O"),
(15, "troops", "NNS", 14, "pobj", "O"),
(16, "are", "VBP", 11, "ccomp", "O"),
(17, "on", "IN", 16, "prep", "O"),
(18, "the", "DT", 19, "det", "O"),
(19, "ground", "NN", 17, "pobj", "O"),
(20, ",", ",", 17, "punct", "O"),
(21, "inside", "IN", 17, "prep", "O"),
(22, "Baghdad", "NNP", 21, "pobj", "U-GPE"),
(23, "itself", "PRP", 22, "appos", "O"),
(24, ",", ",", 16, "punct", "O"),
(25, "have", "VBP", 26, "aux", "O"),
(26, "taken", "VBN", 16, "dep", "O"),
(27, "up", "RP", 26, "prt", "O"),
(28, "positions", "NNS", 26, "dobj", "O"),
(29, "they", "PRP", 31, "nsubj", "O"),
(30, "'re", "VBP", 31, "aux", "O"),
(31, "going", "VBG", 26, "parataxis", "O"),
(32, "to", "TO", 33, "aux", "O"),
(33, "spend", "VB", 31, "xcomp", "O"),
(34, "the", "DT", 35, "det", "B-TIME"),
(35, "night", "NN", 33, "dobj", "L-TIME"),
(36, "there", "RB", 33, "advmod", "O"),
(37, "presumably", "RB", 33, "advmod", "O"),
(38, ",", ",", 44, "punct", "O"),
(39, "how", "WRB", 40, "advmod", "O"),
(40, "many", "JJ", 41, "amod", "O"),
(41, "soldiers", "NNS", 44, "pobj", "O"),
(42, "are", "VBP", 44, "aux", "O"),
(43, "we", "PRP", 44, "nsubj", "O"),
(44, "talking", "VBG", 44, "ROOT", "O"),
(45, "about", "IN", 44, "prep", "O"),
(46, "right", "RB", 47, "advmod", "O"),
(47, "now", "RB", 44, "advmod", "O"),
(48, "?", ".", 44, "punct", "O"),
]
def test_get_oracle_actions():
    """Smoke-test ``get_oracle_sequence`` on the ``annot_tuples`` sentence.

    A ``DependencyParser`` is built over a ``Doc`` of the sentence's words,
    and the moves needed for its arcs are registered on ``parser.moves``.
    The heads and labels are then projectivized, wrapped in a ``GoldParse``,
    preprocessed, and passed to the oracle. There are no explicit
    assertions: the test passes if nothing raises.
    """
    # Transpose the rows into per-column lists; the id and entity columns
    # are not needed here.
    _, words, tags, heads, deps, _ = (list(column) for column in zip(*annot_tuples))

    doc = Doc(Vocab(), words=list(words))
    parser = DependencyParser(doc.vocab)

    # NOTE(review): (1, "") is registered twice; this looks redundant --
    # confirm before removing.
    for move_id, label in [(0, ""), (1, ""), (1, ""), (4, "ROOT")]:
        parser.moves.add_action(move_id, label)

    # Register one labelled move per arc: move 2 when the head lies after
    # the token, move 3 when it lies before.
    for index, (head, dep) in enumerate(zip(heads, deps)):
        if head == index:
            # A self-headed token gets no labelled move.
            continue
        parser.moves.add_action(2 if head > index else 3, dep)

    heads, deps = projectivize(heads, deps)
    gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
    parser.moves.preprocess_gold(gold)
    parser.moves.get_oracle_sequence(doc, gold)