Fix regression test for #615 and remove unnecessary imports

Ines Montani 2017-01-12 16:49:40 +01:00
parent aeb747e10c
commit 51ef75f629
1 changed file with 23 additions and 24 deletions


@@ -1,15 +1,17 @@
 # coding: utf-8
 from __future__ import unicode_literals
-import spacy
-from spacy.attrs import ORTH
+from ...matcher import Matcher
+from ...attrs import ORTH
+from ..util import get_doc
 
-def merge_phrases(matcher, doc, i, matches):
-    '''
-    Merge a phrase. We have to be careful here because we'll change the token indices.
-    To avoid problems, merge all the phrases once we're called on the last match.
-    '''
-    if i != len(matches)-1:
-        return None
-    # Get Span objects
+def test_issue615(en_tokenizer):
+    def merge_phrases(matcher, doc, i, matches):
+        """Merge a phrase. We have to be careful here because we'll change the
+        token indices. To avoid problems, merge all the phrases once we're called
+        on the last match."""
+        if i != len(matches)-1:
+            return None
+        # Get Span objects
@@ -17,19 +19,16 @@ def merge_phrases(matcher, doc, i, matches):
-    for ent_id, label, span in spans:
-        span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
+        for ent_id, label, span in spans:
+            span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
 
-def test_entity_ID_assignment():
-    nlp = spacy.en.English()
-    text = """The golf club is broken"""
-    doc = nlp(text)
-
-    golf_pattern = [
-        { ORTH: "golf"},
-        { ORTH: "club"}
-    ]
-
-    matcher = spacy.matcher.Matcher(nlp.vocab)
-    matcher.add_entity('Sport_Equipment', on_match = merge_phrases)
-    matcher.add_pattern("Sport_Equipment", golf_pattern, label = 'Sport_Equipment')
+    text = "The golf club is broken"
+    pattern = [{ ORTH: "golf"}, { ORTH: "club"}]
+    label = "Sport_Equipment"
+
+    tokens = en_tokenizer(text)
+    doc = get_doc(tokens.vocab, [t.text for t in tokens])
+
+    matcher = Matcher(doc.vocab)
+    matcher.add_entity(label, on_match=merge_phrases)
+    matcher.add_pattern(label, pattern, label=label)
     match = matcher(doc)
     entities = list(doc.ents)
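
The rewritten test depends on two shared test helpers that are not part of this diff: the en_tokenizer pytest fixture and the get_doc utility from the test package. As a minimal sketch only, assuming spaCy 1.x and not the project's actual conftest.py/util.py, they could look roughly like this:

# A minimal sketch, assuming spaCy 1.x; these are stand-ins for the real
# test helpers, shown only to make the fixture and helper usage above concrete.
import pytest
from spacy.en import English
from spacy.tokens import Doc

@pytest.fixture
def en_tokenizer():
    # Tokenizer only -- no statistical models are loaded, which keeps the test fast.
    return English.Defaults.create_tokenizer()

def get_doc(vocab, words):
    # Build a Doc directly from a list of words, bypassing the tokenizer.
    return Doc(vocab, words=words)

With helpers of that shape available to the test package, the regression test runs under pytest as usual, e.g. pytest -k issue615.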