Fix bug in Parser.labels and add test (#4275)

This commit is contained in:
Ines Montani 2019-09-11 18:29:35 +02:00 committed by Matthew Honnibal
parent af93997993
commit 8ebc3711dc
2 changed files with 18 additions and 1 deletions

View File

@ -1063,7 +1063,7 @@ cdef class DependencyParser(Parser):
@property @property
def labels(self): def labels(self):
# Get the labels from the model by looking at the available moves # Get the labels from the model by looking at the available moves
return tuple(set(move.split("-")[1] for move in self.move_names)) return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move))
cdef class EntityRecognizer(Parser): cdef class EntityRecognizer(Parser):

View File

@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
assert ner1.moves.n_moves == ner2.moves.n_moves assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves): for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i) assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
@pytest.mark.parametrize(
    "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
)
def test_add_label_get_label(pipe_cls, n_moves):
    """Check that labels added to a pipe are reported back by ``pipe.labels``.

    Regression test for a bug in DependencyParser.labels, which would fail
    when splitting the move names.
    """
    expected = ["A", "B", "C"]
    pipe = pipe_cls(Vocab())
    for name in expected:
        pipe.add_label(name)
    # The move table is expected to hold ``n_moves`` entries per added label.
    n_expected_moves = len(expected) * n_moves
    assert len(pipe.move_names) == n_expected_moves
    # ``labels`` makes no ordering promise here, so compare in sorted order.
    assert sorted(pipe.labels) == expected