Fix regression test for #615 and remove unnecessary imports

This commit is contained in:
Ines Montani 2017-01-12 16:49:40 +01:00
parent aeb747e10c
commit 51ef75f629
1 changed files with 23 additions and 24 deletions

View File

@ -1,35 +1,34 @@
# coding: utf-8
from __future__ import unicode_literals
import spacy
from spacy.attrs import ORTH
from ...matcher import Matcher
from ...attrs import ORTH
from ..util import get_doc
def merge_phrases(matcher, doc, i, matches):
'''
Merge a phrase. We have to be careful here because we'll change the token indices.
To avoid problems, merge all the phrases once we're called on the last match.
'''
if i != len(matches)-1:
return None
# Get Span objects
spans = [(ent_id, label, doc[start : end]) for ent_id, label, start, end in matches]
for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
def test_issue615(en_tokenizer):
def merge_phrases(matcher, doc, i, matches):
"""Merge a phrase. We have to be careful here because we'll change the
token indices. To avoid problems, merge all the phrases once we're called
on the last match."""
def test_entity_ID_assignment():
nlp = spacy.en.English()
text = """The golf club is broken"""
doc = nlp(text)
if i != len(matches)-1:
return None
# Get Span objects
spans = [(ent_id, label, doc[start : end]) for ent_id, label, start, end in matches]
for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
golf_pattern = [
{ ORTH: "golf"},
{ ORTH: "club"}
]
text = "The golf club is broken"
pattern = [{ ORTH: "golf"}, { ORTH: "club"}]
label = "Sport_Equipment"
matcher = spacy.matcher.Matcher(nlp.vocab)
matcher.add_entity('Sport_Equipment', on_match = merge_phrases)
matcher.add_pattern("Sport_Equipment", golf_pattern, label = 'Sport_Equipment')
tokens = en_tokenizer(text)
doc = get_doc(tokens.vocab, [t.text for t in tokens])
matcher = Matcher(doc.vocab)
matcher.add_entity(label, on_match=merge_phrases)
matcher.add_pattern(label, pattern, label=label)
match = matcher(doc)
entities = list(doc.ents)