mirror of https://github.com/explosion/spaCy.git
Default empty KB in EL component (#5872)
* EL field documentation
* documentation consistent with docs
* default empty KB, initialize vocab separately
* formatting
* add test for changing the default entity vector length
* update comment
parent b7e3018d97
commit 82347110f5
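The change in a nutshell: `KnowledgeBase.__init__` no longer takes a `Vocab`, and the vocab is attached in a separate `kb.initialize(vocab)` step, which lets the `entity_linker` factory default to an empty KB built from config alone. A minimal before/after sketch of the new construction pattern (values borrowed from the tests in this diff):

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()

# before this commit:
#   kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# after this commit: construct first, attach the vocab separately
kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(nlp.vocab)
kb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])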
@@ -48,7 +48,8 @@ def main(model, output_dir=None):
     # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
     # For simplicity, we'll just use the original vector dimension here instead.
     vectors_dim = nlp.vocab.vectors.shape[1]
-    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim)
+    kb = KnowledgeBase(entity_vector_length=vectors_dim)
+    kb.initialize(nlp.vocab)
 
     # set up the data
     entity_ids = []
@@ -95,7 +96,8 @@ def main(model, output_dir=None):
     print("Loading vocab from", vocab_path)
     print("Loading KB from", kb_path)
     vocab2 = Vocab().from_disk(vocab_path)
-    kb2 = KnowledgeBase(vocab=vocab2)
+    kb2 = KnowledgeBase(entity_vector_length=1)
+    kb2.initialize(vocab2)
     kb2.load_bulk(kb_path)
     print()
     _print_kb(kb2)
@@ -374,7 +374,8 @@ class Errors:
     E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
             "includes either the `text` or `tokens` key. For more info, see "
             "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
-    E139 = ("Knowledge Base for component '{name}' is empty.")
+    E139 = ("Knowledge Base for component '{name}' is empty. Use the methods "
+            "kb.add_entity and kb.add_alias to add entries.")
     E140 = ("The list of entities, prior probabilities and entity vectors "
             "should be of equal length.")
     E141 = ("Entity vectors should be of length {required} instead of the "
@@ -481,6 +482,8 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
 
     # TODO: fix numbering after merging develop into master
+    E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
+            "call kb.initialize()?")
    E947 = ("Matcher.add received invalid 'greedy' argument: expected "
             "a string value from {expected} but got: '{arg}'")
     E948 = ("Matcher.add received invalid 'patterns' argument: expected "
spacy/kb.pyx
@@ -71,17 +71,25 @@ cdef class KnowledgeBase:
     DOCS: https://spacy.io/api/kb
     """
 
-    def __init__(self, Vocab vocab, entity_vector_length=64):
-        self.vocab = vocab
+    def __init__(self, entity_vector_length):
+        """Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
         self.mem = Pool()
         self.entity_vector_length = entity_vector_length
 
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
+        self.vocab = None
 
+
+    def initialize(self, Vocab vocab):
+        self.vocab = vocab
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
+    def require_vocab(self):
+        if self.vocab is None:
+            raise ValueError(Errors.E946)
+
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
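Any KB method that touches the vocab now goes through the new `require_vocab()` guard, so forgetting the `initialize()` step fails fast with E946. A small sketch of the failure mode (illustrative usage, not part of the diff):

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(entity_vector_length=3)
try:
    # add_entity() calls require_vocab() first, so this raises
    # ValueError with Errors.E946 while self.vocab is still None
    kb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
except ValueError:
    pass

kb.initialize(Vocab())  # attach a vocab...
kb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])  # ...and the same call succeeds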
@@ -94,12 +102,14 @@ cdef class KnowledgeBase:
         return len(self._entry_index)
 
     def get_entity_strings(self):
+        self.require_vocab()
         return [self.vocab.strings[x] for x in self._entry_index]
 
     def get_size_aliases(self):
         return len(self._alias_index)
 
     def get_alias_strings(self):
+        self.require_vocab()
         return [self.vocab.strings[x] for x in self._alias_index]
 
     def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@@ -107,6 +117,7 @@ cdef class KnowledgeBase:
         Add an entity to the KB, optionally specifying its log probability based on corpus frequency
         Return the hash of the entity ID/name at the end.
         """
+        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
 
         # Return if this entity was added before
@@ -129,6 +140,7 @@ cdef class KnowledgeBase:
         return entity_hash
 
     cpdef set_entities(self, entity_list, freq_list, vector_list):
+        self.require_vocab()
         if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
             raise ValueError(Errors.E140)
 
@@ -164,10 +176,12 @@ cdef class KnowledgeBase:
             i += 1
 
     def contains_entity(self, unicode entity):
+        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
         return entity_hash in self._entry_index
 
     def contains_alias(self, unicode alias):
+        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings.add(alias)
         return alias_hash in self._alias_index
 
@@ -176,6 +190,7 @@ cdef class KnowledgeBase:
         For a given alias, add its potential entities and prior probabilies to the KB.
         Return the alias_hash at the end
         """
+        self.require_vocab()
         # Throw an error if the length of entities and probabilities are not the same
         if not len(entities) == len(probabilities):
             raise ValueError(Errors.E132.format(alias=alias,
@@ -219,6 +234,7 @@ cdef class KnowledgeBase:
         Throw an error if this entity+prior prob would exceed the sum of 1.
         For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
         """
+        self.require_vocab()
         # Check if the alias exists in the KB
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
@@ -265,6 +281,7 @@ cdef class KnowledgeBase:
         and the prior probability of that alias resolving to that entity.
         If the alias is not known in the KB, and empty list is returned.
         """
+        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
             return []
@@ -281,6 +298,7 @@ cdef class KnowledgeBase:
                 if entry_index != 0]
 
     def get_vector(self, unicode entity):
+        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings[entity]
 
         # Return an empty list if this entity is unknown in this KB
@@ -293,6 +311,7 @@ cdef class KnowledgeBase:
     def get_prior_prob(self, unicode entity, unicode alias):
         """ Return the prior probability of a given alias being linked to a given entity,
         or return 0.0 when this combination is not known in the knowledge base"""
+        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         cdef hash_t entity_hash = self.vocab.strings[entity]
 
@@ -311,6 +330,7 @@ cdef class KnowledgeBase:
 
 
     def dump(self, loc):
+        self.require_vocab()
         cdef Writer writer = Writer(loc)
         writer.write_header(self.get_size_entities(), self.entity_vector_length)
 
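With `require_vocab()` also guarding `dump()`, the save/load cycle under the new API looks like the sketch below (mirroring the test changes further down in this diff; the path is illustrative):

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()
kb1 = KnowledgeBase(entity_vector_length=3)
kb1.initialize(nlp.vocab)
kb1.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
kb1.dump("/tmp/kb")  # dump() now requires an initialized vocab

# loading: construct with a matching vector length, initialize, then load
kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(nlp.vocab)
kb2.load_bulk("/tmp/kb")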
@@ -27,6 +27,13 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
 @registry.assets.register("spacy.KBFromFile.v1")
 def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
     vocab = Vocab().from_disk(vocab_path)
-    kb = KnowledgeBase(vocab=vocab)
+    kb = KnowledgeBase(entity_vector_length=1)
+    kb.initialize(vocab)
     kb.load_bulk(kb_path)
     return kb
+
+
+@registry.assets.register("spacy.EmptyKB.v1")
+def empty_kb(entity_vector_length: int) -> KnowledgeBase:
+    kb = KnowledgeBase(entity_vector_length=entity_vector_length)
+    return kb
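The new `spacy.EmptyKB.v1` asset lets a config request an empty KB of a given width without touching any files. A hedged sketch of resolving it by hand (assuming the standard catalogue `registry.assets.get` accessor):

from spacy.util import registry

# look up the factory registered above and build an empty KB
empty_kb_factory = registry.assets.get("spacy.EmptyKB.v1")
kb = empty_kb_factory(entity_vector_length=64)
assert kb.get_size_entities() == 0  # empty until initialize()/add_entity()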
@@ -33,24 +33,31 @@ dropout = null
 """
 DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
 
+default_kb_config = """
+[kb]
+@assets = "spacy.EmptyKB.v1"
+entity_vector_length = 64
+"""
+DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
+
 
 @Language.factory(
     "entity_linker",
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
     assigns=["token.ent_kb_id"],
     default_config={
-        "kb": None,  # TODO - what kind of default makes sense here?
+        "kb": DEFAULT_NEL_KB,
+        "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
-        "model": DEFAULT_NEL_MODEL,
     },
 )
 def make_entity_linker(
     nlp: Language,
     name: str,
     model: Model,
-    kb: Optional[KnowledgeBase],
+    kb: KnowledgeBase,
     *,
     labels_discard: Iterable[str],
     incl_prior: bool,
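With `DEFAULT_NEL_KB` as the factory default, adding the component without any config now yields an empty 64-dimensional KB, and the width can be overridden through the `[kb]` block, as the new tests below exercise:

from spacy.lang.en import English

nlp = English()
# default config: empty KB with entity_vector_length = 64
entity_linker = nlp.add_pipe("entity_linker", config={})
# or override the vector length via the [kb] config block:
# nlp.add_pipe("entity_linker", config={"kb": {"entity_vector_length": 35}})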
@@ -92,10 +99,10 @@ class EntityLinker(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        kb (KnowledgeBase): TODO:
-        labels_discard (Iterable[str]): TODO:
-        incl_prior (bool): TODO:
-        incl_context (bool): TODO:
+        kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
+        labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
+        incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
+        incl_context (bool): Whether or not to include the local context in the model.
 
         DOCS: https://spacy.io/api/entitylinker#init
         """
@@ -108,14 +115,12 @@ class EntityLinker(Pipe):
             "incl_prior": incl_prior,
             "incl_context": incl_context,
         }
-        self.kb = kb
-        if self.kb is None:
-            # create an empty KB that should be filled by calling from_disk
-            self.kb = KnowledgeBase(vocab=vocab)
-        else:
-            del cfg["kb"]  # we don't want to duplicate its serialization
-        if not isinstance(self.kb, KnowledgeBase):
+        if not isinstance(kb, KnowledgeBase):
             raise ValueError(Errors.E990.format(type=type(self.kb)))
+        kb.initialize(vocab)
+        self.kb = kb
+        if "kb" in cfg:
+            del cfg["kb"]  # we don't want to duplicate its serialization
         self.cfg = dict(cfg)
         self.distance = CosineDistance(normalize=False)
         # how many neightbour sentences to take into account
@@ -437,9 +442,8 @@ class EntityLinker(Pipe):
             raise ValueError(Errors.E149)
 
         def load_kb(p):
-            self.kb = KnowledgeBase(
-                vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]
-            )
+            self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
+            self.kb.initialize(self.vocab)
             self.kb.load_bulk(p)
 
         deserialize = {}
@@ -21,7 +21,8 @@ def assert_almost_equal(a, b):
 
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@@ -50,7 +51,8 @@ def test_kb_valid_entities(nlp):
 
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -66,7 +68,8 @@ def test_kb_invalid_entities(nlp):
 
 def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -80,7 +83,8 @@ def test_kb_invalid_probabilities(nlp):
 
 def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -96,7 +100,8 @@ def test_kb_invalid_combination(nlp):
 
 def test_kb_invalid_entity_vector(nlp):
     """Test the invalid construction of a KB with non-matching entity vector lengths"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@@ -106,9 +111,44 @@ def test_kb_invalid_entity_vector(nlp):
         mykb.add_entity(entity="Q2", freq=5, entity_vector=[2])
 
 
+def test_kb_default(nlp):
+    """Test that the default (empty) KB is loaded when not providing a config"""
+    entity_linker = nlp.add_pipe("entity_linker", config={})
+    assert len(entity_linker.kb) == 0
+    assert entity_linker.kb.get_size_entities() == 0
+    assert entity_linker.kb.get_size_aliases() == 0
+    assert entity_linker.kb.entity_vector_length == 64  # default value from pipeline.entity_linker
+
+
+def test_kb_custom_length(nlp):
+    """Test that the default (empty) KB can be configured with a custom entity length"""
+    entity_linker = nlp.add_pipe("entity_linker", config={"kb": {"entity_vector_length": 35}})
+    assert len(entity_linker.kb) == 0
+    assert entity_linker.kb.get_size_entities() == 0
+    assert entity_linker.kb.get_size_aliases() == 0
+    assert entity_linker.kb.entity_vector_length == 35
+
+
+def test_kb_undefined(nlp):
+    """Test that the EL can't train without defining a KB"""
+    entity_linker = nlp.add_pipe("entity_linker", config={})
+    with pytest.raises(ValueError):
+        entity_linker.begin_training()
+
+
+def test_kb_empty(nlp):
+    """Test that the EL can't train with an empty KB"""
+    config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
+    entity_linker = nlp.add_pipe("entity_linker", config=config)
+    assert len(entity_linker.kb) == 0
+    with pytest.raises(ValueError):
+        entity_linker.begin_training()
+
+
 def test_candidate_generation(nlp):
     """Test correct candidate generation"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -133,7 +173,8 @@ def test_candidate_generation(nlp):
 
 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -163,7 +204,8 @@ def test_append_alias(nlp):
 
 def test_append_invalid_alias(nlp):
     """Test that append an alias will throw an error if prior probs are exceeding 1"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -184,7 +226,8 @@ def test_preserving_links_asdoc(nlp):
 
     @registry.assets.register("myLocationsKB.v1")
     def dummy_kb() -> KnowledgeBase:
-        mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+        mykb = KnowledgeBase(entity_vector_length=1)
+        mykb.initialize(nlp.vocab)
         # adding entities
         mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
         mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@@ -289,7 +332,8 @@ def test_overfitting_IO():
     # create artificial KB - assign same prior weight to the two russ cochran's
     # Q2146908 (Russ Cochran): American golfer
     # Q7381115 (Russ Cochran): publisher
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
     mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
     mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
     mykb.add_alias(
@@ -139,7 +139,8 @@ def test_issue4665():
 def test_issue4674():
     """Test that setting entities with overlapping identifiers does not mess up IO"""
     nlp = English()
-    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb = KnowledgeBase(entity_vector_length=3)
+    kb.initialize(nlp.vocab)
     vector1 = [0.9, 1.1, 1.01]
     vector2 = [1.8, 2.25, 2.01]
     with pytest.warns(UserWarning):
@@ -156,7 +157,8 @@ def test_issue4674():
         dir_path.mkdir()
     file_path = dir_path / "kb"
     kb.dump(str(file_path))
-    kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
+    kb2 = KnowledgeBase(entity_vector_length=3)
+    kb2.initialize(nlp.vocab)
     kb2.load_bulk(str(file_path))
     assert kb2.get_size_entities() == 1
 
@@ -72,7 +72,8 @@ def entity_linker():
 
     @registry.assets.register("TestIssue5230KB.v1")
     def dummy_kb() -> KnowledgeBase:
-        kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+        kb = KnowledgeBase(entity_vector_length=1)
+        kb.initialize(nlp.vocab)
         kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
         return kb
 
@@ -121,7 +122,8 @@ def test_writer_with_path_py35():
 
 def test_save_and_load_knowledge_base():
     nlp = Language()
-    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    kb = KnowledgeBase(entity_vector_length=1)
+    kb.initialize(nlp.vocab)
     with make_tempdir() as d:
         path = d / "kb"
         try:
@@ -130,7 +132,8 @@ def test_save_and_load_knowledge_base():
             pytest.fail(str(e))
 
         try:
-            kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+            kb_loaded = KnowledgeBase(entity_vector_length=1)
+            kb_loaded.initialize(nlp.vocab)
             kb_loaded.load_bulk(path)
         except Exception as e:
             pytest.fail(str(e))
@@ -17,7 +17,8 @@ def test_serialize_kb_disk(en_vocab):
     file_path = dir_path / "kb"
     kb1.dump(str(file_path))
 
-    kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
+    kb2 = KnowledgeBase(entity_vector_length=3)
+    kb2.initialize(en_vocab)
     kb2.load_bulk(str(file_path))
 
     # final assertions
@@ -25,7 +26,8 @@ def test_serialize_kb_disk(en_vocab):
 
 
 def _get_dummy_kb(vocab):
-    kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
+    kb = KnowledgeBase(entity_vector_length=3)
+    kb.initialize(vocab)
 
     kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
     kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])