check and unit test in case prior probs exceed 1

This commit is contained in:
svlandeg 2019-03-19 21:43:48 +01:00
parent b55baaa1dc
commit 33f8a0fe2e
3 changed files with 34 additions and 4 deletions

View File

@ -35,6 +35,13 @@ cdef class KnowledgeBase:
def add_alias(self, unicode alias, entities, probabilities): def add_alias(self, unicode alias, entities, probabilities):
"""For a given alias, add its potential entities and prior probabilies to the KB.""" """For a given alias, add its potential entities and prior probabilies to the KB."""
# Throw an error if the probabilities sum up to more than 1
prob_sum = sum(probabilities)
if prob_sum > 1:
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
"but found " + str(prob_sum))
cdef hash_t alias_hash = self.strings.add(alias) cdef hash_t alias_hash = self.strings.add(alias)
# Return if this alias was added before # Return if this alias was added before

View File

@ -42,6 +42,12 @@ def create_kb():
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
alias2 = "johny"
print(" adding alias2", alias2)
mykb.add_alias(alias=alias2, entities=["Q0", "Q42"], probabilities=[0.3, 1.1])
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
print("candidates for", alias) print("candidates for", alias)
candidates = mykb.get_candidates(alias) candidates = mykb.get_candidates(alias)
print(" ", candidates) print(" ", candidates)

View File

@ -1,14 +1,16 @@
# coding: utf-8
import pytest import pytest
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
def test_kb_valid_entities(): def test_kb_valid_entities():
"""Test the valid construction of a KB with 3 entities and one alias"""
mykb = KnowledgeBase() mykb = KnowledgeBase()
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.5) mykb.add_entity(entity_id="Q1", prob=0.9)
mykb.add_entity(entity_id="Q2", prob=0.5) mykb.add_entity(entity_id="Q2", prob=0.2)
mykb.add_entity(entity_id="Q3", prob=0.5) mykb.add_entity(entity_id="Q3", prob=0.5)
# adding aliases # adding aliases
@ -16,14 +18,29 @@ def test_kb_valid_entities():
def test_kb_invalid_entities(): def test_kb_invalid_entities():
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase() mykb = KnowledgeBase()
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.5) mykb.add_entity(entity_id="Q1", prob=0.9)
mykb.add_entity(entity_id="Q2", prob=0.5) mykb.add_entity(entity_id="Q2", prob=0.2)
mykb.add_entity(entity_id="Q3", prob=0.5) mykb.add_entity(entity_id="Q3", prob=0.5)
# adding aliases - should fail because one of the given IDs is not valid # adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError): with pytest.raises(ValueError):
mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]) mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
def test_kb_invalid_probabilities():
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase()
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)
mykb.add_entity(entity_id="Q2", prob=0.2)
mykb.add_entity(entity_id="Q3", prob=0.5)
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])