From b438cfddbccae254c7b409fd8c54d2d5bf9980db Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Jan 2017 17:51:46 +0100 Subject: [PATCH] Modernise matcher tests and split into two files --- spacy/tests/matcher/test_entity_id.py | 74 ++++++++++++--------------- spacy/tests/matcher/test_matcher.py | 15 ++++++ 2 files changed, 49 insertions(+), 40 deletions(-) create mode 100644 spacy/tests/matcher/test_matcher.py diff --git a/spacy/tests/matcher/test_entity_id.py b/spacy/tests/matcher/test_entity_id.py index b55db36af..9982a3f44 100644 --- a/spacy/tests/matcher/test_entity_id.py +++ b/spacy/tests/matcher/test_entity_id.py @@ -1,59 +1,53 @@ +# coding: utf-8 from __future__ import unicode_literals -import spacy -from spacy.vocab import Vocab -from spacy.matcher import Matcher -from spacy.tokens.doc import Doc -from spacy.attrs import * + +from ...matcher import Matcher +from ...attrs import ORTH +from ..util import get_doc import pytest -@pytest.fixture -def en_vocab(): - return spacy.get_lang_class('en').Defaults.create_vocab() - - -def test_init_matcher(en_vocab): +@pytest.mark.parametrize('words,entity', [ + (["Test", "Entity"], "TestEntity")]) +def test_matcher_add_empty_entity(en_vocab, words, entity): matcher = Matcher(en_vocab) + matcher.add_entity(entity) + doc = get_doc(en_vocab, words) assert matcher.n_patterns == 0 - assert matcher(Doc(en_vocab, words=[u'Some', u'words'])) == [] + assert matcher(doc) == [] -def test_add_empty_entity(en_vocab): +@pytest.mark.parametrize('entity1,entity2,attrs', [ + ("TestEntity", "TestEntity2", {"Hello": "World"})]) +def test_matcher_get_entity_attrs(en_vocab, entity1, entity2, attrs): matcher = Matcher(en_vocab) - matcher.add_entity('TestEntity') + matcher.add_entity(entity1) + assert matcher.get_entity(entity1) == {} + matcher.add_entity(entity2, attrs=attrs) + assert matcher.get_entity(entity2) == attrs + assert matcher.get_entity(entity1) == {} + + +@pytest.mark.parametrize('words,entity,attrs', + [(["Test", "Entity"], "TestEntity", {"Hello": "World"})]) +def test_matcher_get_entity_via_match(en_vocab, words, entity, attrs): + matcher = Matcher(en_vocab) + matcher.add_entity(entity, attrs=attrs) + doc = get_doc(en_vocab, words) assert matcher.n_patterns == 0 - assert matcher(Doc(en_vocab, words=[u'Test', u'Entity'])) == [] + assert matcher(doc) == [] - -def test_get_entity_attrs(en_vocab): - matcher = Matcher(en_vocab) - matcher.add_entity('TestEntity') - entity = matcher.get_entity('TestEntity') - assert entity == {} - matcher.add_entity('TestEntity2', attrs={'Hello': 'World'}) - entity = matcher.get_entity('TestEntity2') - assert entity == {'Hello': 'World'} - assert matcher.get_entity('TestEntity') == {} - - -def test_get_entity_via_match(en_vocab): - matcher = Matcher(en_vocab) - matcher.add_entity('TestEntity', attrs={u'Hello': u'World'}) - assert matcher.n_patterns == 0 - assert matcher(Doc(en_vocab, words=[u'Test', u'Entity'])) == [] - matcher.add_pattern(u'TestEntity', [{ORTH: u'Test'}, {ORTH: u'Entity'}]) + matcher.add_pattern(entity, [{ORTH: words[0]}, {ORTH: words[1]}]) assert matcher.n_patterns == 1 - matches = matcher(Doc(en_vocab, words=[u'Test', u'Entity'])) + + matches = matcher(doc) assert len(matches) == 1 assert len(matches[0]) == 4 + ent_id, label, start, end = matches[0] - assert ent_id == matcher.vocab.strings[u'TestEntity'] + assert ent_id == matcher.vocab.strings[entity] assert label == 0 assert start == 0 assert end == 2 - attrs = matcher.get_entity(ent_id) - assert attrs == {u'Hello': u'World'} - - - + assert matcher.get_entity(ent_id) == attrs diff --git a/spacy/tests/matcher/test_matcher.py b/spacy/tests/matcher/test_matcher.py new file mode 100644 index 000000000..1b75f4f92 --- /dev/null +++ b/spacy/tests/matcher/test_matcher.py @@ -0,0 +1,15 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from ...matcher import Matcher +from ..util import get_doc + +import pytest + + +@pytest.mark.parametrize('words', [["Some", "words"]]) +def test_matcher_init(en_vocab, words): + matcher = Matcher(en_vocab) + doc = get_doc(en_vocab, words) + assert matcher.n_patterns == 0 + assert matcher(doc) == []