diff --git a/spacy/tests/parser/test_neural_parser.py b/spacy/tests/parser/test_neural_parser.py index 42b55745f..30a6367c8 100644 --- a/spacy/tests/parser/test_neural_parser.py +++ b/spacy/tests/parser/test_neural_parser.py @@ -78,3 +78,16 @@ def test_predict_doc_beam(parser, tok2vec, model, doc): parser(doc, beam_width=32, beam_density=0.001) for word in doc: print(word.text, word.head, word.dep_) + + +def test_update_doc_beam(parser, tok2vec, model, doc, gold): + parser.model = model + tokvecs, bp_tokvecs = tok2vec.begin_update([doc]) + d_tokvecs = parser.update_beam(([doc], tokvecs), [gold]) + assert d_tokvecs[0].shape == tokvecs[0].shape + def optimize(weights, gradient, key=None): + weights -= 0.001 * gradient + bp_tokvecs(d_tokvecs, sgd=optimize) + assert d_tokvecs[0].sum() == 0. + + diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py new file mode 100644 index 000000000..45c85d969 --- /dev/null +++ b/spacy/tests/parser/test_nn_beam.py @@ -0,0 +1,87 @@ +from __future__ import unicode_literals +import pytest +import numpy +from thinc.api import layerize + +from ...vocab import Vocab +from ...syntax.arc_eager import ArcEager +from ...tokens import Doc +from ...gold import GoldParse +from ...syntax._beam_utils import ParserBeam, update_beam +from ...syntax.stateclass import StateClass + + +@pytest.fixture +def vocab(): + return Vocab() + +@pytest.fixture +def moves(vocab): + aeager = ArcEager(vocab.strings, {}) + aeager.add_action(2, 'nsubj') + aeager.add_action(3, 'dobj') + aeager.add_action(2, 'aux') + return aeager + + +@pytest.fixture +def docs(vocab): + return [Doc(vocab, words=['Rats', 'bite', 'things'])] + +@pytest.fixture +def states(docs): + return [StateClass(doc) for doc in docs] + +@pytest.fixture +def tokvecs(docs, vector_size): + output = [] + for doc in docs: + vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size)) + output.append(numpy.asarray(vec)) + return output + + +@pytest.fixture +def golds(docs): + return [GoldParse(doc) for doc in docs] + + +@pytest.fixture +def batch_size(docs): + return len(docs) + + +@pytest.fixture +def beam_width(): + return 4 + + +@pytest.fixture +def vector_size(): + return 6 + + +@pytest.fixture +def beam(moves, states, golds, beam_width): + return ParserBeam(moves, states, golds, width=beam_width) + +@pytest.fixture +def scores(moves, batch_size, beam_width): + return [ + numpy.asarray( + numpy.random.uniform(-0.1, 0.1, (batch_size, moves.n_moves)), + dtype='f') + for _ in range(batch_size)] + + +def test_create_beam(beam): + pass + + +def test_beam_advance(beam, scores): + beam.advance(scores) + + +def test_beam_advance_too_few_scores(beam, scores): + with pytest.raises(IndexError): + beam.advance(scores[:-1])