spaCy/spacy/tests/parser/test_add_label.py

71 lines
2.0 KiB
Python
Raw Normal View History

2017-10-09 01:42:35 +00:00
'''Test the ability to add a label to a (potentially trained) parsing model.'''
from __future__ import unicode_literals
import pytest
import numpy.random
from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps
from ...attrs import NORM
from ...gold import GoldParse
from ...vocab import Vocab
from ...tokens import Doc
from ...pipeline import DependencyParser
2017-10-09 01:42:35 +00:00
# Seed NumPy's global random number generator at import time so that any
# NumPy-based randomness in this module is repeatable from run to run.
# NOTE(review): presumably this is what makes the short training loops below
# deterministic -- confirm that the parser draws its randomness from NumPy.
numpy.random.seed(0)
@pytest.fixture
def vocab():
    """Build a fresh Vocab whose NORM getter returns the string unchanged."""
    def _identity(string):
        return string

    getters = {NORM: _identity}
    return Vocab(lex_attr_getters=getters)
@pytest.fixture
def parser(vocab):
    """Return a small DependencyParser that knows the 'left' label.

    The parser is configured with tiny layer widths, given the single label
    'left', initialised with ``begin_training`` and then updated ten times on
    one four-word example whose gold heads are [1, 1, 3, 3].
    """
    model = DependencyParser(vocab)
    # Small widths keep the fixture cheap to build for every test.
    model.cfg['token_vector_width'] = 8
    model.cfg['hidden_width'] = 30
    model.cfg['hist_size'] = 0
    model.add_label('left')
    model.begin_training([], **model.cfg)

    sgd = Adam(NumpyOps(), 0.001)
    for _ in range(10):
        # A new Doc, GoldParse and losses dict are built for every update.
        losses = {}
        doc = Doc(vocab, words=['a', 'b', 'c', 'd'])
        gold = GoldParse(doc, heads=[1, 1, 3, 3],
                         deps=['left', 'ROOT', 'left', 'ROOT'])
        model.update([doc], [gold], sgd=sgd, losses=losses)
    return model
2017-10-10 20:57:41 +00:00
def test_init_parser(parser):
    """Pass as long as the ``parser`` fixture can be built without raising."""
2017-10-09 01:42:35 +00:00
def _assert_left_parse(doc):
    """Assert the heads and label that the 'left'-only training should yield."""
    assert doc[0].head.i == 1
    assert doc[0].dep_ == 'left'
    assert doc[1].head.i == 1
    # NOTE(review): the original repeated this exact assertion twice in a row;
    # the second copy may have been intended to check doc[3] -- confirm before
    # adding a stricter check.
    assert doc[2].head.i == 3


def test_add_label(parser):
    """Check that a label can be added to an already-updated parser.

    Steps:
      1. The parser from the fixture must produce the expected 'left' parse.
      2. After ``add_label('right')`` the same parse must still be produced.
      3. After ten further updates on gold data that uses 'right' for the
         first token, the parser must predict 'right' for token 0 and still
         predict 'left' for token 2.
    """
    words = ['a', 'b', 'c', 'd']

    # 1. Predictions before the new label exists.
    _assert_left_parse(parser(Doc(parser.vocab, words=words)))

    # 2. Adding a label must not disturb the existing predictions.
    parser.add_label('right')
    _assert_left_parse(parser(Doc(parser.vocab, words=words)))

    # 3. Continue updating with gold data that uses the new label.
    sgd = Adam(NumpyOps(), 0.001)
    for _ in range(10):
        losses = {}
        doc = Doc(parser.vocab, words=words)
        gold = GoldParse(doc, heads=[1, 1, 3, 3],
                         deps=['right', 'ROOT', 'left', 'ROOT'])
        parser.update([doc], [gold], sgd=sgd, losses=losses)

    doc = parser(Doc(parser.vocab, words=words))
    assert doc[0].dep_ == 'right'
    assert doc[2].dep_ == 'left'