mirror of https://github.com/explosion/spaCy.git

ensure Span.as_doc keeps the entity links + unit test

parent 58a5b40ef6
commit 8608685543
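For context, a hedged sketch (not part of the diff) of the behaviour this
commit guarantees, condensed from the unit test added below. It assumes the
spaCy v2.x API used in the test, with `nlp` built as a plain English()
pipeline; the entity_linker relies on prior probabilities only, so no
trained model is needed.

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English
    from spacy.pipeline import EntityRuler

    nlp = English()

    # a one-entity knowledge base plus an alias for the surface form "Boston"
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
    kb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
    kb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])

    # sentencizer + rule-based NER + entity linker (prior probability only)
    nlp.add_pipe(nlp.create_pipe("sentencizer"))
    ruler = EntityRuler(nlp)
    ruler.add_patterns([{"label": "GPE", "pattern": "Boston"}])
    nlp.add_pipe(ruler)
    el_pipe = nlp.create_pipe(name='entity_linker', config={})
    el_pipe.set_kb(kb)
    el_pipe.begin_training()
    el_pipe.context_weight = 0
    el_pipe.prior_weight = 1
    nlp.add_pipe(el_pipe, last=True)

    doc = nlp("She lives in Boston.")
    for ent in doc.ents:
        sent_doc = ent.sent.as_doc()  # copy the sentence into a fresh Doc
        for s_ent in sent_doc.ents:
            if s_ent.text == ent.text:
                assert s_ent.kb_id_ == ent.kb_id_  # the KB link survives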
@@ -82,6 +82,7 @@ cdef enum attr_id_t:
     DEP
     ENT_IOB
     ENT_TYPE
+    ENT_KB_ID
     HEAD
     SENT_START
     SPACY
@@ -84,6 +84,7 @@ IDS = {
     "DEP": DEP,
     "ENT_IOB": ENT_IOB,
     "ENT_TYPE": ENT_TYPE,
+    "ENT_KB_ID": ENT_KB_ID,
     "HEAD": HEAD,
     "SENT_START": SENT_START,
     "SPACY": SPACY,
@@ -81,6 +81,7 @@ cdef enum symbol_t:
     DEP
     ENT_IOB
     ENT_TYPE
+    ENT_KB_ID
     HEAD
     SENT_START
     SPACY
@@ -86,6 +86,7 @@ IDS = {
     "DEP": DEP,
     "ENT_IOB": ENT_IOB,
     "ENT_TYPE": ENT_TYPE,
+    "ENT_KB_ID": ENT_KB_ID,
     "HEAD": HEAD,
     "SENT_START": SENT_START,
     "SPACY": SPACY,
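The four hunks above register the new ENT_KB_ID constant in both the
attribute table (attrs) and the symbol table (symbols), so the per-token KB
id can be addressed by name like any other token attribute. A small hedged
check of what this enables, assuming the commit is applied:

    from spacy.attrs import ENT_KB_ID, IDS

    # the new constant resolves by name through the IDS mapping
    assert IDS["ENT_KB_ID"] == ENT_KB_ID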
@@ -5,6 +5,7 @@ import pytest
 
 from spacy.kb import KnowledgeBase
 from spacy.lang.en import English
+from spacy.pipeline import EntityRuler
 
 
 @pytest.fixture
@@ -101,3 +102,44 @@ def test_candidate_generation(nlp):
     assert(len(mykb.get_candidates('douglas')) == 2)
     assert(len(mykb.get_candidates('adam')) == 1)
     assert(len(mykb.get_candidates('shrubbery')) == 0)
+
+
+def test_preserving_links_asdoc(nlp):
+    """Test that Span.as_doc preserves the existing entity links"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+
+    # adding aliases
+    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
+    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+
+    # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
+    sentencizer = nlp.create_pipe("sentencizer")
+    nlp.add_pipe(sentencizer)
+
+    ruler = EntityRuler(nlp)
+    patterns = [{"label": "GPE", "pattern": "Boston"},
+                {"label": "GPE", "pattern": "Denver"}]
+    ruler.add_patterns(patterns)
+    nlp.add_pipe(ruler)
+
+    el_pipe = nlp.create_pipe(name='entity_linker', config={})
+    el_pipe.set_kb(mykb)
+    el_pipe.begin_training()
+    el_pipe.context_weight = 0
+    el_pipe.prior_weight = 1
+    nlp.add_pipe(el_pipe, last=True)
+
+    # test whether the entity links are preserved by the `as_doc()` function
+    text = "She lives in Boston. He lives in Denver."
+    doc = nlp(text)
+    for ent in doc.ents:
+        orig_text = ent.text
+        orig_kb_id = ent.kb_id_
+        sent_doc = ent.sent.as_doc()
+        for s_ent in sent_doc.ents:
+            if s_ent.text == orig_text:
+                assert s_ent.kb_id_ == orig_kb_id
@@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
 from ..typedefs cimport attr_t, flags_t
 from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
 from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
-from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t
+from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
 from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
 
 from ..attrs import intify_attrs, IDS
@@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
         return token.ent_iob
     elif feat_name == ENT_TYPE:
         return token.ent_type
+    elif feat_name == ENT_KB_ID:
+        return token.ent_kb_id
     else:
         return Lexeme.get_struct_attr(token.lex, feat_name)
 
@@ -850,7 +852,7 @@ cdef class Doc:
 
         DOCS: https://spacy.io/api/doc#to_bytes
         """
-        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
+        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]  # TODO: ENT_KB_ID ?
         if self.is_tagged:
             array_head.append(TAG)
         # If doc parsed add head and dep attribute
@@ -1004,6 +1006,7 @@ cdef class Doc:
         """
         cdef unicode tag, lemma, ent_type
         deprecation_warning(Warnings.W013.format(obj="Doc"))
+        # TODO: ENT_KB_ID ?
         if len(args) == 3:
             deprecation_warning(Warnings.W003)
             tag, lemma, ent_type = args
@@ -210,7 +210,7 @@ cdef class Span:
         words = [t.text for t in self]
         spaces = [bool(t.whitespace_) for t in self]
         cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
-        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
+        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
        if self.doc.is_tagged:
             array_head.append(TAG)
         # If doc parsed add head and dep attribute
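The one-line change to array_head above is the core fix: Span.as_doc copies
token attributes into the new Doc through Doc.to_array / Doc.from_array, and
only attributes listed in array_head make the trip. With ENT_KB_ID in the
list, the per-token KB id is serialized and restored like ENT_IOB and
ENT_TYPE. A hedged sketch of that round-trip mechanism, assuming the commit
is applied:

    from spacy.attrs import ENT_IOB, ENT_TYPE, ENT_KB_ID
    from spacy.lang.en import English

    nlp = English()
    src = nlp("She lives in Boston.")
    attrs = [ENT_IOB, ENT_TYPE, ENT_KB_ID]
    arr = src.to_array(attrs)    # one column per listed attribute
    dest = nlp("She lives in Boston.")
    dest.from_array(attrs, arr)  # restores the per-token KB ids as well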
@@ -53,6 +53,8 @@ cdef class Token:
             return token.ent_iob
         elif feat_name == ENT_TYPE:
             return token.ent_type
+        elif feat_name == ENT_KB_ID:
+            return token.ent_kb_id
         elif feat_name == SENT_START:
             return token.sent_start
         else:

@@ -79,5 +81,7 @@ cdef class Token:
             token.ent_iob = value
         elif feat_name == ENT_TYPE:
             token.ent_type = value
+        elif feat_name == ENT_KB_ID:
+            token.ent_kb_id = value
         elif feat_name == SENT_START:
             token.sent_start = value