diff --git a/spacy/attrs.pxd b/spacy/attrs.pxd index 79a177ba9..c5ba8d765 100644 --- a/spacy/attrs.pxd +++ b/spacy/attrs.pxd @@ -82,6 +82,7 @@ cdef enum attr_id_t: DEP ENT_IOB ENT_TYPE + ENT_KB_ID HEAD SENT_START SPACY diff --git a/spacy/attrs.pyx b/spacy/attrs.pyx index ed1f39a3f..8eeea363f 100644 --- a/spacy/attrs.pyx +++ b/spacy/attrs.pyx @@ -84,6 +84,7 @@ IDS = { "DEP": DEP, "ENT_IOB": ENT_IOB, "ENT_TYPE": ENT_TYPE, + "ENT_KB_ID": ENT_KB_ID, "HEAD": HEAD, "SENT_START": SENT_START, "SPACY": SPACY, diff --git a/spacy/symbols.pxd b/spacy/symbols.pxd index 051b92edb..4501861a2 100644 --- a/spacy/symbols.pxd +++ b/spacy/symbols.pxd @@ -81,6 +81,7 @@ cdef enum symbol_t: DEP ENT_IOB ENT_TYPE + ENT_KB_ID HEAD SENT_START SPACY diff --git a/spacy/symbols.pyx b/spacy/symbols.pyx index 949621820..b65ae9628 100644 --- a/spacy/symbols.pyx +++ b/spacy/symbols.pyx @@ -86,6 +86,7 @@ IDS = { "DEP": DEP, "ENT_IOB": ENT_IOB, "ENT_TYPE": ENT_TYPE, + "ENT_KB_ID": ENT_KB_ID, "HEAD": HEAD, "SENT_START": SENT_START, "SPACY": SPACY, diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index b12ad3917..7ea893408 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -5,6 +5,7 @@ import pytest from spacy.kb import KnowledgeBase from spacy.lang.en import English +from spacy.pipeline import EntityRuler @pytest.fixture @@ -101,3 +102,44 @@ def test_candidate_generation(nlp): assert(len(mykb.get_candidates('douglas')) == 2) assert(len(mykb.get_candidates('adam')) == 1) assert(len(mykb.get_candidates('shrubbery')) == 0) + + +def test_preserving_links_asdoc(nlp): + """Test that Span.as_doc preserves the existing entity links""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1]) + + # adding aliases + mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7]) + mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6]) + + # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) + sentencizer = nlp.create_pipe("sentencizer") + nlp.add_pipe(sentencizer) + + ruler = EntityRuler(nlp) + patterns = [{"label": "GPE", "pattern": "Boston"}, + {"label": "GPE", "pattern": "Denver"}] + ruler.add_patterns(patterns) + nlp.add_pipe(ruler) + + el_pipe = nlp.create_pipe(name='entity_linker', config={}) + el_pipe.set_kb(mykb) + el_pipe.begin_training() + el_pipe.context_weight = 0 + el_pipe.prior_weight = 1 + nlp.add_pipe(el_pipe, last=True) + + # test whether the entity links are preserved by the `as_doc()` function + text = "She lives in Boston. He lives in Denver." + doc = nlp(text) + for ent in doc.ents: + orig_text = ent.text + orig_kb_id = ent.kb_id_ + sent_doc = ent.sent.as_doc() + for s_ent in sent_doc.ents: + if s_ent.text == orig_text: + assert s_ent.kb_id_ == orig_kb_id diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 131c43d37..10f57ed60 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME from ..typedefs cimport attr_t, flags_t from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB -from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t +from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t from ..attrs import intify_attrs, IDS @@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil: return token.ent_iob elif feat_name == ENT_TYPE: return token.ent_type + elif feat_name == ENT_KB_ID: + return token.ent_kb_id else: return Lexeme.get_struct_attr(token.lex, feat_name) @@ -850,7 +852,7 @@ cdef class Doc: DOCS: https://spacy.io/api/doc#to_bytes """ - array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] + array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ? if self.is_tagged: array_head.append(TAG) # If doc parsed add head and dep attribute @@ -1004,6 +1006,7 @@ cdef class Doc: """ cdef unicode tag, lemma, ent_type deprecation_warning(Warnings.W013.format(obj="Doc")) + # TODO: ENT_KB_ID ? if len(args) == 3: deprecation_warning(Warnings.W003) tag, lemma, ent_type = args diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 97b6a1adc..3f4f4418b 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -210,7 +210,7 @@ cdef class Span: words = [t.text for t in self] spaces = [bool(t.whitespace_) for t in self] cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) - array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] + array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID] if self.doc.is_tagged: array_head.append(TAG) # If doc parsed add head and dep attribute diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd index bb9f7d070..ec5df3fac 100644 --- a/spacy/tokens/token.pxd +++ b/spacy/tokens/token.pxd @@ -53,6 +53,8 @@ cdef class Token: return token.ent_iob elif feat_name == ENT_TYPE: return token.ent_type + elif feat_name == ENT_KB_ID: + return token.ent_kb_id elif feat_name == SENT_START: return token.sent_start else: @@ -79,5 +81,7 @@ cdef class Token: token.ent_iob = value elif feat_name == ENT_TYPE: token.ent_type = value + elif feat_name == ENT_KB_ID: + token.ent_kb_id = value elif feat_name == SENT_START: token.sent_start = value