diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 31bfbe56d..7f19ab455 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -8,7 +8,8 @@ from spacy.attrs import NORM from spacy.gold import GoldParse from spacy.vocab import Vocab from spacy.tokens import Doc -from spacy.pipeline import DependencyParser +from spacy.pipeline import DependencyParser, EntityRecognizer +from spacy.util import fix_random_seed @pytest.fixture @@ -19,18 +20,6 @@ def vocab(): @pytest.fixture def parser(vocab): parser = DependencyParser(vocab) - parser.cfg["token_vector_width"] = 8 - parser.cfg["hidden_width"] = 30 - parser.cfg["hist_size"] = 0 - parser.add_label("left") - parser.begin_training([], **parser.cfg) - sgd = Adam(NumpyOps(), 0.001) - - for i in range(10): - losses = {} - doc = Doc(vocab, words=["a", "b", "c", "d"]) - gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"]) - parser.update([doc], [gold], sgd=sgd, losses=losses) return parser @@ -38,10 +27,22 @@ def test_init_parser(parser): pass -# TODO: This is flakey, because it depends on what the parser first learns. -# TODO: This now seems to be implicated in segfaults. Not sure what's up! -@pytest.mark.skip +def _train_parser(parser): + fix_random_seed(1) + parser.add_label("left") + parser.begin_training([], **parser.cfg) + sgd = Adam(NumpyOps(), 0.001) + + for i in range(10): + losses = {} + doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) + gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"]) + parser.update([doc], [gold], sgd=sgd, losses=losses) + return parser + + def test_add_label(parser): + parser = _train_parser(parser) doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) doc = parser(doc) assert doc[0].head.i == 1 @@ -69,3 +70,16 @@ def test_add_label(parser): doc = parser(doc) assert doc[0].dep_ == "right" assert doc[2].dep_ == "left" + + +@pytest.mark.xfail +def test_add_label_deserializes_correctly(): + ner1 = EntityRecognizer(Vocab()) + ner1.add_label("C") + ner1.add_label("B") + ner1.add_label("A") + ner1.begin_training([]) + ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes()) + assert ner1.moves.n_moves == ner2.moves.n_moves + for i in range(ner1.moves.n_moves): + assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)