spaCy/spacy/tests/serialize/test_serialize_kb.py

74 lines
2.5 KiB
Python
Raw Normal View History

2019-06-19 11:11:39 +00:00
# coding: utf-8
2019-04-24 21:52:34 +00:00
from ..util import make_tempdir
from ...util import ensure_path
from spacy.kb import KnowledgeBase
def test_serialize_kb_disk(en_vocab):
    """Round-trip a KnowledgeBase through ``dump`` / ``load_bulk`` on disk.

    Builds the dummy KB and verifies it, dumps it to a file inside a
    temporary directory, loads that file into a fresh KnowledgeBase sharing
    the same vocab, and verifies the loaded copy with the same checks.
    """
    # baseline assertions
    kb1 = _get_dummy_kb(en_vocab)
    _check_kb(kb1)

    # dumping to file & loading back in
    with make_tempdir() as d:
        dir_path = ensure_path(d)
        if not dir_path.exists():
            dir_path.mkdir()
        file_path = dir_path / "kb"
        # dump/load_bulk are called with a plain str path, not the Path object
        kb1.dump(str(file_path))

        # same entity_vector_length (3) as the KB built by _get_dummy_kb
        kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
        kb2.load_bulk(str(file_path))

    # final assertions
    _check_kb(kb2)
def _get_dummy_kb(vocab):
    """Return a small KnowledgeBase fixture built on ``vocab``.

    The KB holds four entities (each with a 3-dimensional vector and a prior
    frequency) and three aliases mapping surface strings to those entities.
    """
    knowledge_base = KnowledgeBase(vocab=vocab, entity_vector_length=3)

    # (entity id, frequency, vector), registered in this order
    entity_rows = [
        (u'Q53', 0.33, [0, 5, 3]),
        (u'Q17', 0.2, [7, 1, 0]),
        (u'Q007', 0.7, [0, 0, 7]),
        (u'Q44', 0.4, [4, 4, 4]),
    ]
    for ent_id, freq, vec in entity_rows:
        knowledge_base.add_entity(entity=ent_id, prob=freq, entity_vector=vec)

    # (alias, candidate entity ids, prior probabilities), registered in this order
    alias_rows = [
        (u'double07', [u'Q17', u'Q007'], [0.1, 0.9]),
        (u'guy', [u'Q53', u'Q007', u'Q17', u'Q44'], [0.3, 0.3, 0.2, 0.1]),
        (u'random', [u'Q007'], [1.0]),
    ]
    for surface, targets, priors in alias_rows:
        knowledge_base.add_alias(alias=surface, entities=targets, probabilities=priors)

    return knowledge_base
2019-04-24 21:52:34 +00:00
def _check_kb(kb):
    """Assert that ``kb`` contains exactly the fixture built by ``_get_dummy_kb``.

    Checks the entity and alias counts, the presence/absence of specific
    entity and alias strings, and the two candidates returned for the alias
    ``double07`` (entity id, frequency, vector, alias, prior probability).
    Raises ``AssertionError`` on the first mismatch.
    """
    def close_to(actual, expected, tol=0.001):
        # Probabilities read back from the KB are compared approximately
        # rather than with ==, using the same symmetric window everywhere.
        return abs(actual - expected) < tol

    # check entities
    assert kb.get_size_entities() == 4
    # fetch the collection once instead of once per loop iteration
    entity_strings = set(kb.get_entity_strings())
    for entity_string in [u'Q53', u'Q17', u'Q007', u'Q44']:
        assert entity_string in entity_strings
    for entity_string in [u'', u'Q0']:
        assert entity_string not in entity_strings

    # check aliases
    assert kb.get_size_aliases() == 3
    alias_strings = set(kb.get_alias_strings())
    for alias_string in [u'double07', u'guy', u'random']:
        assert alias_string in alias_strings
    for alias_string in [u'nothingness', u'', u'randomnoise']:
        assert alias_string not in alias_strings

    # check candidates & probabilities (sorted by entity id for a stable order)
    candidates = sorted(kb.get_candidates(u'double07'), key=lambda x: x.entity_)
    assert len(candidates) == 2

    assert candidates[0].entity_ == u'Q007'
    assert close_to(candidates[0].entity_freq, 0.7)
    assert candidates[0].entity_vector == [0, 0, 7]
    assert candidates[0].alias_ == u'double07'
    assert close_to(candidates[0].prior_prob, 0.9)

    assert candidates[1].entity_ == u'Q17'
    assert close_to(candidates[1].entity_freq, 0.2)
    assert candidates[1].entity_vector == [7, 1, 0]
    assert candidates[1].alias_ == u'double07'
    assert close_to(candidates[1].prior_prob, 0.1)