diff --git a/tests/test_ner.py b/tests/test_ner.py index 1523e02d8..bc492c9bf 100644 --- a/tests/test_ner.py +++ b/tests/test_ner.py @@ -6,18 +6,15 @@ import pytest @pytest.fixture def labels(): - return ['LOC', 'MISC', 'ORG', 'PER'] + ent_types = ['LOC', 'MISC', 'ORG', 'PER'] + moves = ['B', 'I', 'L', 'U'] + labels = ['NULL', 'EOL', 'O'] + for move in moves: + for ent_type in ent_types: + labels.append('%s-%s' % (move, ent_type)) + return labels -def test_n_moves(labels): - s = PyState(labels, 5) - b_moves = len(labels) - i_moves = len(labels) - l_moves = len(labels) - u_moves = len(labels) - o_moves = 1 - assert s.n_classes == b_moves + i_moves + l_moves + u_moves + o_moves - @pytest.fixture def sentence(): return "Ms. Haag plays Elianti .".split() @@ -35,7 +32,7 @@ def test_begin(state, sentence): assert state.n_ents == 0 assert state.i == 1 assert state.open_entity - assert state.ent == {'start': 0, 'label': 3, 'end': 0} + assert state.ent == {'start': 0, 'label': 4, 'end': 0} assert state.is_valid('I-PER') assert not state.is_valid('I-LOC') assert state.is_valid('L-PER')