Fix regression test for #615 and remove unnecessary imports

This commit is contained in:
Ines Montani 2017-01-12 16:49:40 +01:00
parent aeb747e10c
commit 51ef75f629
1 changed file with 23 additions and 24 deletions

View File

@ -1,15 +1,17 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import spacy from ...matcher import Matcher
from spacy.attrs import ORTH from ...attrs import ORTH
from ..util import get_doc
def test_issue615(en_tokenizer):
def merge_phrases(matcher, doc, i, matches): def merge_phrases(matcher, doc, i, matches):
''' """Merge a phrase. We have to be careful here because we'll change the
Merge a phrase. We have to be careful here because we'll change the token indices. token indices. To avoid problems, merge all the phrases once we're called
To avoid problems, merge all the phrases once we're called on the last match. on the last match."""
'''
if i != len(matches)-1: if i != len(matches)-1:
return None return None
# Get Span objects # Get Span objects
@ -17,19 +19,16 @@ def merge_phrases(matcher, doc, i, matches):
for ent_id, label, span in spans: for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label]) span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
def test_entity_ID_assignment(): text = "The golf club is broken"
nlp = spacy.en.English() pattern = [{ ORTH: "golf"}, { ORTH: "club"}]
text = """The golf club is broken""" label = "Sport_Equipment"
doc = nlp(text)
golf_pattern = [ tokens = en_tokenizer(text)
{ ORTH: "golf"}, doc = get_doc(tokens.vocab, [t.text for t in tokens])
{ ORTH: "club"}
]
matcher = spacy.matcher.Matcher(nlp.vocab) matcher = Matcher(doc.vocab)
matcher.add_entity('Sport_Equipment', on_match = merge_phrases) matcher.add_entity(label, on_match=merge_phrases)
matcher.add_pattern("Sport_Equipment", golf_pattern, label = 'Sport_Equipment') matcher.add_pattern(label, pattern, label=label)
match = matcher(doc) match = matcher(doc)
entities = list(doc.ents) entities = list(doc.ents)