Default empty KB in EL component (#5872)

* EL field documentation

* documentation consistent with docs

* default empty KB, initialize vocab separately

* formatting

* add test for changing the default entity vector length

* update comment
Sofie Van Landeghem 2020-08-04 14:34:09 +02:00 committed by GitHub
parent b7e3018d97
commit 82347110f5
9 changed files with 127 additions and 40 deletions
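
In short: a KnowledgeBase no longer takes a Vocab at construction time; the vocab is wired in through a separate initialize() step. A minimal sketch of the new pattern, assuming the v3 development API shown in the diffs below (entity "Q2146908" and its alias are taken from the test data in this commit):

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(entity_vector_length=3)  # no Vocab argument anymore
    kb.initialize(nlp.vocab)                    # the vocab is supplied separately
    kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
    kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])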

View File

@@ -48,7 +48,8 @@ def main(model, output_dir=None):
     # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
     # For simplicity, we'll just use the original vector dimension here instead.
     vectors_dim = nlp.vocab.vectors.shape[1]
-    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim)
+    kb = KnowledgeBase(entity_vector_length=vectors_dim)
+    kb.initialize(nlp.vocab)
 
     # set up the data
     entity_ids = []
@@ -95,7 +96,8 @@ def main(model, output_dir=None):
         print("Loading vocab from", vocab_path)
         print("Loading KB from", kb_path)
         vocab2 = Vocab().from_disk(vocab_path)
-        kb2 = KnowledgeBase(vocab=vocab2)
+        kb2 = KnowledgeBase(entity_vector_length=1)
+        kb2.initialize(vocab2)
         kb2.load_bulk(kb_path)
         print()
         _print_kb(kb2)

View File

@@ -374,7 +374,8 @@ class Errors:
     E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
             "includes either the `text` or `tokens` key. For more info, see "
             "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
-    E139 = ("Knowledge Base for component '{name}' is empty.")
+    E139 = ("Knowledge Base for component '{name}' is empty. Use the methods "
+            "kb.add_entity and kb.add_alias to add entries.")
     E140 = ("The list of entities, prior probabilities and entity vectors "
             "should be of equal length.")
     E141 = ("Entity vectors should be of length {required} instead of the "
@@ -481,6 +482,8 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
 
     # TODO: fix numbering after merging develop into master
+    E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
+            "call kb.initialize()?")
     E947 = ("Matcher.add received invalid 'greedy' argument: expected "
             "a string value from {expected} but got: '{arg}'")
     E948 = ("Matcher.add received invalid 'patterns' argument: expected "

View File

@@ -71,17 +71,25 @@ cdef class KnowledgeBase:
     DOCS: https://spacy.io/api/kb
     """
 
-    def __init__(self, Vocab vocab, entity_vector_length=64):
-        self.vocab = vocab
+    def __init__(self, entity_vector_length):
+        """Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
         self.mem = Pool()
         self.entity_vector_length = entity_vector_length
 
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
+        self.vocab = None
 
+    def initialize(self, Vocab vocab):
+        self.vocab = vocab
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
+    def require_vocab(self):
+        if self.vocab is None:
+            raise ValueError(Errors.E946)
+
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
@@ -94,12 +102,14 @@ cdef class KnowledgeBase:
         return len(self._entry_index)
 
     def get_entity_strings(self):
+        self.require_vocab()
         return [self.vocab.strings[x] for x in self._entry_index]
 
     def get_size_aliases(self):
         return len(self._alias_index)
 
     def get_alias_strings(self):
+        self.require_vocab()
         return [self.vocab.strings[x] for x in self._alias_index]
 
     def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@@ -107,6 +117,7 @@ cdef class KnowledgeBase:
         Add an entity to the KB, optionally specifying its log probability based on corpus frequency
         Return the hash of the entity ID/name at the end.
         """
+        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
 
         # Return if this entity was added before
@@ -129,6 +140,7 @@ cdef class KnowledgeBase:
         return entity_hash
 
     cpdef set_entities(self, entity_list, freq_list, vector_list):
+        self.require_vocab()
         if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
             raise ValueError(Errors.E140)
@@ -164,10 +176,12 @@ cdef class KnowledgeBase:
             i += 1
 
     def contains_entity(self, unicode entity):
+        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
         return entity_hash in self._entry_index
 
     def contains_alias(self, unicode alias):
+        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings.add(alias)
         return alias_hash in self._alias_index
@@ -176,6 +190,7 @@ cdef class KnowledgeBase:
        For a given alias, add its potential entities and prior probabilities to the KB.
        Return the alias_hash at the end
        """
+        self.require_vocab()
        # Throw an error if the length of entities and probabilities are not the same
        if not len(entities) == len(probabilities):
            raise ValueError(Errors.E132.format(alias=alias,
@@ -219,6 +234,7 @@ cdef class KnowledgeBase:
        Throw an error if this entity+prior prob would exceed the sum of 1.
        For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
        """
+        self.require_vocab()
        # Check if the alias exists in the KB
        cdef hash_t alias_hash = self.vocab.strings[alias]
        if not alias_hash in self._alias_index:
@@ -265,6 +281,7 @@ cdef class KnowledgeBase:
        and the prior probability of that alias resolving to that entity.
        If the alias is not known in the KB, an empty list is returned.
        """
+        self.require_vocab()
        cdef hash_t alias_hash = self.vocab.strings[alias]
        if not alias_hash in self._alias_index:
            return []
@@ -281,6 +298,7 @@ cdef class KnowledgeBase:
                if entry_index != 0]
 
    def get_vector(self, unicode entity):
+        self.require_vocab()
        cdef hash_t entity_hash = self.vocab.strings[entity]
 
        # Return an empty list if this entity is unknown in this KB
@@ -293,6 +311,7 @@ cdef class KnowledgeBase:
    def get_prior_prob(self, unicode entity, unicode alias):
        """ Return the prior probability of a given alias being linked to a given entity,
        or return 0.0 when this combination is not known in the knowledge base"""
+        self.require_vocab()
        cdef hash_t alias_hash = self.vocab.strings[alias]
        cdef hash_t entity_hash = self.vocab.strings[entity]
@@ -311,6 +330,7 @@ cdef class KnowledgeBase:
    def dump(self, loc):
+        self.require_vocab()
        cdef Writer writer = Writer(loc)
        writer.write_header(self.get_size_entities(), self.entity_vector_length)
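
Taken together, the require_vocab() guards mean that any vocab-dependent call on a KB that was constructed but never initialized now fails fast with E946 instead of hitting a null vocab. A small sketch of the expected behaviour, assuming the methods shown above:

    from spacy.kb import KnowledgeBase

    kb = KnowledgeBase(entity_vector_length=3)
    try:
        kb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])  # no initialize() yet
    except ValueError as err:
        print(err)  # E946: the Vocab for the knowledge base is not initialized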

View File

@@ -27,6 +27,13 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
 @registry.assets.register("spacy.KBFromFile.v1")
 def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
     vocab = Vocab().from_disk(vocab_path)
-    kb = KnowledgeBase(vocab=vocab)
+    kb = KnowledgeBase(entity_vector_length=1)
+    kb.initialize(vocab)
     kb.load_bulk(kb_path)
     return kb
+
+
+@registry.assets.register("spacy.EmptyKB.v1")
+def empty_kb(entity_vector_length: int) -> KnowledgeBase:
+    kb = KnowledgeBase(entity_vector_length=entity_vector_length)
+    return kb
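
For illustration, the new spacy.EmptyKB.v1 asset can also be resolved by name, much as the config system does; a sketch assuming catalogue's standard Registry.get lookup on registry.assets:

    from spacy.util import registry

    empty_kb_factory = registry.assets.get("spacy.EmptyKB.v1")
    kb = empty_kb_factory(entity_vector_length=64)
    assert kb.get_size_entities() == 0  # empty until add_entity/add_alias are called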

View File

@@ -33,24 +33,31 @@ dropout = null
 """
 DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
 
+default_kb_config = """
+[kb]
+@assets = "spacy.EmptyKB.v1"
+entity_vector_length = 64
+"""
+DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
+
 
 @Language.factory(
     "entity_linker",
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
     assigns=["token.ent_kb_id"],
     default_config={
-        "kb": None,  # TODO - what kind of default makes sense here?
+        "kb": DEFAULT_NEL_KB,
+        "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
-        "model": DEFAULT_NEL_MODEL,
     },
 )
 def make_entity_linker(
     nlp: Language,
     name: str,
     model: Model,
-    kb: Optional[KnowledgeBase],
+    kb: KnowledgeBase,
     *,
     labels_discard: Iterable[str],
     incl_prior: bool,
@@ -92,10 +99,10 @@ class EntityLinker(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        kb (KnowledgeBase): TODO:
-        labels_discard (Iterable[str]): TODO:
-        incl_prior (bool): TODO:
-        incl_context (bool): TODO:
+        kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
+        labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
+        incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
+        incl_context (bool): Whether or not to include the local context in the model.
 
         DOCS: https://spacy.io/api/entitylinker#init
         """
@@ -108,14 +115,12 @@ class EntityLinker(Pipe):
             "incl_prior": incl_prior,
             "incl_context": incl_context,
         }
-        self.kb = kb
-        if self.kb is None:
-            # create an empty KB that should be filled by calling from_disk
-            self.kb = KnowledgeBase(vocab=vocab)
-        else:
-            del cfg["kb"]  # we don't want to duplicate its serialization
-        if not isinstance(self.kb, KnowledgeBase):
+        if not isinstance(kb, KnowledgeBase):
             raise ValueError(Errors.E990.format(type=type(self.kb)))
+        kb.initialize(vocab)
+        self.kb = kb
+        if "kb" in cfg:
+            del cfg["kb"]  # we don't want to duplicate its serialization
         self.cfg = dict(cfg)
         self.distance = CosineDistance(normalize=False)
         # how many neighbour sentences to take into account
@@ -437,9 +442,8 @@ class EntityLinker(Pipe):
             raise ValueError(Errors.E149)
 
         def load_kb(p):
-            self.kb = KnowledgeBase(
-                vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]
-            )
+            self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
+            self.kb.initialize(self.vocab)
             self.kb.load_bulk(p)
 
         deserialize = {}
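
With DEFAULT_NEL_KB as the fallback, the empty KB's vector length becomes configurable per component; a sketch mirroring the new tests in the next file:

    from spacy.lang.en import English

    nlp = English()
    entity_linker = nlp.add_pipe(
        "entity_linker", config={"kb": {"entity_vector_length": 35}}
    )
    assert entity_linker.kb.entity_vector_length == 35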

View File

@@ -21,7 +21,8 @@ def assert_almost_equal(a, b):
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@@ -50,7 +51,8 @@ def test_kb_valid_entities(nlp):
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -66,7 +68,8 @@ def test_kb_invalid_entities(nlp):
 def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -80,7 +83,8 @@ def test_kb_invalid_probabilities(nlp):
 def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -96,7 +100,8 @@ def test_kb_invalid_combination(nlp):
 def test_kb_invalid_entity_vector(nlp):
     """Test the invalid construction of a KB with non-matching entity vector lengths"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@@ -106,9 +111,44 @@ def test_kb_invalid_entity_vector(nlp):
         mykb.add_entity(entity="Q2", freq=5, entity_vector=[2])
 
 
+def test_kb_default(nlp):
+    """Test that the default (empty) KB is loaded when not providing a config"""
+    entity_linker = nlp.add_pipe("entity_linker", config={})
+    assert len(entity_linker.kb) == 0
+    assert entity_linker.kb.get_size_entities() == 0
+    assert entity_linker.kb.get_size_aliases() == 0
+    assert entity_linker.kb.entity_vector_length == 64  # default value from pipeline.entity_linker
+
+
+def test_kb_custom_length(nlp):
+    """Test that the default (empty) KB can be configured with a custom entity length"""
+    entity_linker = nlp.add_pipe("entity_linker", config={"kb": {"entity_vector_length": 35}})
+    assert len(entity_linker.kb) == 0
+    assert entity_linker.kb.get_size_entities() == 0
+    assert entity_linker.kb.get_size_aliases() == 0
+    assert entity_linker.kb.entity_vector_length == 35
+
+
+def test_kb_undefined(nlp):
+    """Test that the EL can't train without defining a KB"""
+    entity_linker = nlp.add_pipe("entity_linker", config={})
+    with pytest.raises(ValueError):
+        entity_linker.begin_training()
+
+
+def test_kb_empty(nlp):
+    """Test that the EL can't train with an empty KB"""
+    config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
+    entity_linker = nlp.add_pipe("entity_linker", config=config)
+    assert len(entity_linker.kb) == 0
+    with pytest.raises(ValueError):
+        entity_linker.begin_training()
+
+
 def test_candidate_generation(nlp):
     """Test correct candidate generation"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -133,7 +173,8 @@ def test_candidate_generation(nlp):
 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -163,7 +204,8 @@ def test_append_alias(nlp):
 def test_append_invalid_alias(nlp):
     """Test that appending an alias will throw an error if prior probs are exceeding 1"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(entity_vector_length=1)
+    mykb.initialize(nlp.vocab)
 
     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -184,7 +226,8 @@ def test_preserving_links_asdoc(nlp):
     @registry.assets.register("myLocationsKB.v1")
     def dummy_kb() -> KnowledgeBase:
-        mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+        mykb = KnowledgeBase(entity_vector_length=1)
+        mykb.initialize(nlp.vocab)
         # adding entities
         mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
         mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@@ -289,7 +332,8 @@ def test_overfitting_IO():
     # create artificial KB - assign same prior weight to the two russ cochran's
     # Q2146908 (Russ Cochran): American golfer
     # Q7381115 (Russ Cochran): publisher
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    mykb = KnowledgeBase(entity_vector_length=3)
+    mykb.initialize(nlp.vocab)
     mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
     mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
     mykb.add_alias(

View File

@@ -139,7 +139,8 @@ def test_issue4665():
 def test_issue4674():
     """Test that setting entities with overlapping identifiers does not mess up IO"""
     nlp = English()
-    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb = KnowledgeBase(entity_vector_length=3)
+    kb.initialize(nlp.vocab)
     vector1 = [0.9, 1.1, 1.01]
     vector2 = [1.8, 2.25, 2.01]
     with pytest.warns(UserWarning):
@@ -156,7 +157,8 @@ def test_issue4674():
         dir_path.mkdir()
         file_path = dir_path / "kb"
         kb.dump(str(file_path))
-        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
+        kb2 = KnowledgeBase(entity_vector_length=3)
+        kb2.initialize(nlp.vocab)
         kb2.load_bulk(str(file_path))
         assert kb2.get_size_entities() == 1

View File

@@ -72,7 +72,8 @@ def entity_linker():
     @registry.assets.register("TestIssue5230KB.v1")
     def dummy_kb() -> KnowledgeBase:
-        kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+        kb = KnowledgeBase(entity_vector_length=1)
+        kb.initialize(nlp.vocab)
         kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
         return kb
@@ -121,7 +122,8 @@ def test_writer_with_path_py35():
 def test_save_and_load_knowledge_base():
     nlp = Language()
-    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    kb = KnowledgeBase(entity_vector_length=1)
+    kb.initialize(nlp.vocab)
     with make_tempdir() as d:
         path = d / "kb"
         try:
@@ -130,7 +132,8 @@ def test_save_and_load_knowledge_base():
             pytest.fail(str(e))
 
         try:
-            kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+            kb_loaded = KnowledgeBase(entity_vector_length=1)
+            kb_loaded.initialize(nlp.vocab)
             kb_loaded.load_bulk(path)
         except Exception as e:
             pytest.fail(str(e))

View File

@@ -17,7 +17,8 @@ def test_serialize_kb_disk(en_vocab):
     file_path = dir_path / "kb"
     kb1.dump(str(file_path))
-    kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
+    kb2 = KnowledgeBase(entity_vector_length=3)
+    kb2.initialize(en_vocab)
     kb2.load_bulk(str(file_path))
 
     # final assertions
@@ -25,7 +26,8 @@ def test_serialize_kb_disk(en_vocab):
 def _get_dummy_kb(vocab):
-    kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
+    kb = KnowledgeBase(entity_vector_length=3)
+    kb.initialize(vocab)
     kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
     kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])