From 8ebc3711dc1ec065c39aeb6017d9ace129a28d3f Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Wed, 11 Sep 2019 18:29:35 +0200 Subject: [PATCH] Fix bug in Parser.labels and add test (#4275) --- spacy/pipeline/pipes.pyx | 2 +- spacy/tests/parser/test_add_label.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 90ccc2fbf..095021f00 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1063,7 +1063,7 @@ cdef class DependencyParser(Parser): @property def labels(self): # Get the labels from the model by looking at the available moves - return tuple(set(move.split("-")[1] for move in self.move_names)) + return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move)) cdef class EntityRecognizer(Parser): diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 45a51ac8e..4ab9c1e70 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly(): assert ner1.moves.n_moves == ner2.moves.n_moves for i in range(ner1.moves.n_moves): assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i) + + +@pytest.mark.parametrize( + "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)] +) +def test_add_label_get_label(pipe_cls, n_moves): + """Test that added labels are returned correctly. This test was added to + test for a bug in DependencyParser.labels that'd cause it to fail when + splitting the move names. + """ + labels = ["A", "B", "C"] + pipe = pipe_cls(Vocab()) + for label in labels: + pipe.add_label(label) + assert len(pipe.move_names) == len(labels) * n_moves + pipe_labels = sorted(list(pipe.labels)) + assert pipe_labels == labels