mirror of https://github.com/explosion/spaCy.git
231 lines
7.9 KiB
Python
231 lines
7.9 KiB
Python
import pytest
|
|
from spacy.vocab import Vocab
|
|
from spacy import registry
|
|
from spacy.training import Example
|
|
from spacy.pipeline import DependencyParser
|
|
from spacy.tokens import Doc
|
|
from spacy.pipeline._parser_internals.nonproj import projectivize
|
|
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
|
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
|
|
|
|
|
def get_sequence_costs(M, words, heads, deps, transitions):
    """Replay ``transitions`` on a gold example and record every move's cost.

    Args:
        M: The transition system to drive (the tests pass an ``ArcEager``).
        words: Token strings for the document.
        heads: Gold head index for each token.
        deps: Gold dependency label for each token.
        transitions: Names of the moves to apply, in order.

    Returns:
        A ``(state, cost_history)`` pair. ``state`` is the parser state after
        all transitions were applied. ``cost_history`` holds one dict per
        step, mapping each move's class name to the cost ``M.get_cost``
        reported *before* that step's transition was taken.
    """
    example = Example.from_dict(
        Doc(Vocab(), words=words), {"heads": heads, "deps": deps}
    )
    states, golds, _ = M.init_gold_batch([example])
    state, gold = states[0], golds[0]
    history = []
    for move_name in transitions:
        # Sync the gold object with the current state before querying costs.
        gold.update(state)
        history.append(
            {M.class_name(k): M.get_cost(state, gold, k) for k in range(M.n_moves)}
        )
        M.transition(state, move_name)
    return state, history
|
|
|
|
|
|
@pytest.fixture
def vocab():
    """Provide a newly constructed, empty ``Vocab`` to the requesting test."""
    fresh_vocab = Vocab()
    return fresh_vocab
|
|
|
|
|
|
@pytest.fixture
def arc_eager(vocab):
    """Build an ``ArcEager`` transition system over ``vocab``'s string store.

    Starts from the default ``ArcEager.get_actions()`` and then registers a
    Left move labelled ``"left"`` and a Right move labelled ``"right"``.
    """
    system = ArcEager(vocab.strings, ArcEager.get_actions())
    for move, label in ((2, "left"), (3, "right")):
        system.add_action(move, label)
    return system
|
|
|
|
|
|
def test_oracle_four_words(arc_eager, vocab):
    """At each step the gold move costs 0 and every other move costs >= 1."""
    words = ["a", "b", "c", "d"]
    heads = [1, 1, 3, 3]
    deps = ["left", "ROOT", "left", "ROOT"]
    for label in deps:
        # Make both attachment directions available for every gold label.
        arc_eager.add_action(2, label)  # Left
        arc_eager.add_action(3, label)  # Right
    actions = ["L-left", "B-ROOT", "L-left"]
    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
    assert state.is_final()
    # get_sequence_costs records exactly one cost dict per applied action.
    for step, (gold_move, costs) in enumerate(zip(actions, cost_history)):
        assert costs[gold_move] == 0.0, gold_move
        wrong_moves = {name: c for name, c in costs.items() if name != gold_move}
        for other_action, cost in wrong_moves.items():
            assert cost >= 1, (step, other_action)
|
|
|
|
|
|
# Gold annotation for one long sentence, one row per token:
# (token index, word, POS tag, head token index, dependency label, entity tag).
# Head indices are absolute token positions; the ROOT token ("talking", 44)
# is its own head. Entity tags use B-/I-/L-/U-/O prefixes.
# test_get_oracle_actions unpacks rows in exactly this column order.
annot_tuples = [
    (0, "When", "WRB", 11, "advmod", "O"),
    (1, "Walter", "NNP", 2, "compound", "B-PERSON"),
    (2, "Rodgers", "NNP", 11, "nsubj", "L-PERSON"),
    (3, ",", ",", 2, "punct", "O"),
    (4, "our", "PRP$", 6, "poss", "O"),
    (5, "embedded", "VBN", 6, "amod", "O"),
    (6, "reporter", "NN", 2, "appos", "O"),
    (7, "with", "IN", 6, "prep", "O"),
    (8, "the", "DT", 10, "det", "B-ORG"),
    (9, "3rd", "NNP", 10, "compound", "I-ORG"),
    (10, "Cavalry", "NNP", 7, "pobj", "L-ORG"),
    (11, "says", "VBZ", 44, "advcl", "O"),
    (12, "three", "CD", 13, "nummod", "U-CARDINAL"),
    (13, "battalions", "NNS", 16, "nsubj", "O"),
    (14, "of", "IN", 13, "prep", "O"),
    (15, "troops", "NNS", 14, "pobj", "O"),
    (16, "are", "VBP", 11, "ccomp", "O"),
    (17, "on", "IN", 16, "prep", "O"),
    (18, "the", "DT", 19, "det", "O"),
    (19, "ground", "NN", 17, "pobj", "O"),
    (20, ",", ",", 17, "punct", "O"),
    (21, "inside", "IN", 17, "prep", "O"),
    (22, "Baghdad", "NNP", 21, "pobj", "U-GPE"),
    (23, "itself", "PRP", 22, "appos", "O"),
    (24, ",", ",", 16, "punct", "O"),
    (25, "have", "VBP", 26, "aux", "O"),
    (26, "taken", "VBN", 16, "dep", "O"),
    (27, "up", "RP", 26, "prt", "O"),
    (28, "positions", "NNS", 26, "dobj", "O"),
    (29, "they", "PRP", 31, "nsubj", "O"),
    (30, "'re", "VBP", 31, "aux", "O"),
    (31, "going", "VBG", 26, "parataxis", "O"),
    (32, "to", "TO", 33, "aux", "O"),
    (33, "spend", "VB", 31, "xcomp", "O"),
    (34, "the", "DT", 35, "det", "B-TIME"),
    (35, "night", "NN", 33, "dobj", "L-TIME"),
    (36, "there", "RB", 33, "advmod", "O"),
    (37, "presumably", "RB", 33, "advmod", "O"),
    (38, ",", ",", 44, "punct", "O"),
    (39, "how", "WRB", 40, "advmod", "O"),
    (40, "many", "JJ", 41, "amod", "O"),
    (41, "soldiers", "NNS", 44, "pobj", "O"),
    (42, "are", "VBP", 44, "aux", "O"),
    (43, "we", "PRP", 44, "nsubj", "O"),
    (44, "talking", "VBG", 44, "ROOT", "O"),
    (45, "about", "IN", 44, "prep", "O"),
    (46, "right", "RB", 47, "advmod", "O"),
    (47, "now", "RB", 44, "advmod", "O"),
    (48, "?", ".", 44, "punct", "O"),
]
|
|
|
|
|
|
def test_get_oracle_actions():
    """Build a parser for ``annot_tuples`` and check the oracle yields moves.

    The gold tree is projectivized first. Every (possibly decorated) label it
    then contains is registered as a Left or Right action, so the arc-eager
    oracle should produce a non-empty transition sequence for the example.
    """
    # Transpose the rows into per-column lists. Token ids and entity tags
    # are not needed to derive the parser oracle.
    _ids, words, tags, heads, deps, _ents = (
        list(column) for column in zip(*annot_tuples)
    )
    doc = Doc(Vocab(), words=words)
    config = {
        "learn_tokens": False,
        "min_action_freq": 0,
        "update_with_oracle_cut_size": 100,
    }
    cfg = {"model": DEFAULT_PARSER_MODEL}
    model = registry.make_from_config(cfg, validate=True)["model"]
    parser = DependencyParser(doc.vocab, model, **config)
    # Unlabelled moves plus the ROOT label. 2 = Left and 3 = Right, as
    # registered below. 0, 1 and 4 are presumably Shift, Reduce and Break
    # (NOTE(review): confirm against ArcEager's move codes). The original
    # registered (1, "") twice; add_action is relied on elsewhere in this
    # file to be idempotent, so once is enough.
    parser.moves.add_action(0, "")
    parser.moves.add_action(1, "")
    parser.moves.add_action(4, "ROOT")
    heads, deps = projectivize(heads, deps)
    for i, (head, dep) in enumerate(zip(heads, deps)):
        if head > i:
            parser.moves.add_action(2, dep)
        elif head < i:
            parser.moves.add_action(3, dep)
    example = Example.from_dict(
        doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
    )
    oracle = parser.moves.get_oracle_sequence(example)
    # Guard against a vacuous pass: a 49-token sentence needs transitions.
    assert len(oracle) > 0
|
|
|
|
|
|
def test_oracle_dev_sentence(vocab, arc_eager):
    """Check the arc-eager static oracle on a dev-set sentence.

    The gold parse is written as ``word dep head-word`` lines. Head words are
    resolved to token indices by looking the word up, so every word must be
    unique. The oracle's transition names must equal ``expected_transitions``
    exactly. The ``vocab`` fixture is not used directly here.
    """
    words_deps_heads = """
        Rolls-Royce nn Inc.
        Motor nn Inc.
        Cars nn Inc.
        Inc. nsubj said
        said ROOT said
        it nsubj expects
        expects ccomp said
        its poss sales
        U.S. nn sales
        sales nsubj steady
        to aux steady
        remain cop steady
        steady xcomp expects
        at prep steady
        about quantmod 1,200
        1,200 num cars
        cars pobj at
        in prep steady
        1990 pobj in
        . punct said
    """
    expected_transitions = [
        "S",  # Shift 'Motor'
        "S",  # Shift 'Cars'
        "L-nn",  # Attach 'Cars' to 'Inc.'
        "L-nn",  # Attach 'Motor' to 'Inc.'
        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.', force shift
        "L-nsubj",  # Attach 'Inc.' to 'said'
        "S",  # Shift 'it'
        "L-nsubj",  # Attach 'it' to 'expects'
        "R-ccomp",  # Attach 'expects' to 'said'
        "S",  # Shift 'its'
        "S",  # Shift 'U.S.'
        "L-nn",  # Attach 'U.S.' to 'sales'
        "L-poss",  # Attach 'its' to 'sales'
        "S",  # Shift 'sales'
        "S",  # Shift 'to'
        "S",  # Shift 'remain'
        "L-cop",  # Attach 'remain' to 'steady'
        "L-aux",  # Attach 'to' to 'steady'
        "L-nsubj",  # Attach 'sales' to 'steady'
        "R-xcomp",  # Attach 'steady' to 'expects'
        "R-prep",  # Attach 'at' to 'steady'
        "S",  # Shift 'about'
        "L-quantmod",  # Attach 'about' to '1,200'
        "S",  # Shift '1,200'
        "L-num",  # Attach '1,200' to 'cars'
        "R-pobj",  # Attach 'cars' to 'at'
        "D",  # Reduce 'cars'
        "D",  # Reduce 'at'
        "R-prep",  # Attach 'in' to 'steady'
        "R-pobj",  # Attach '1990' to 'in'
        "D",  # Reduce '1990'
        "D",  # Reduce 'in'
        "D",  # Reduce 'steady'
        "D",  # Reduce 'expects'
        "R-punct",  # Attach '.' to 'said'
    ]

    gold_words = []
    gold_deps = []
    gold_heads = []
    for line in words_deps_heads.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        word, dep, head = line.split()
        gold_words.append(word)
        gold_deps.append(dep)
        gold_heads.append(head)
    # Heads are given as words. Mapping a word to an index is only
    # unambiguous when every word occurs once, so fail loudly on duplicates
    # instead of silently attaching to the first match.
    assert len(set(gold_words)) == len(gold_words), "gold words must be unique"
    word_index = {word: i for i, word in enumerate(gold_words)}
    gold_heads = [word_index[head] for head in gold_heads]
    for dep in gold_deps:
        arc_eager.add_action(2, dep)  # Left
        arc_eager.add_action(3, dep)  # Right

    doc = Doc(Vocab(), words=gold_words)
    example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})

    ae_oracle_actions = arc_eager.get_oracle_sequence(example)
    ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
    assert ae_oracle_actions == expected_transitions
|