mirror of https://github.com/explosion/spaCy.git
check and unit test in case prior probs exceed 1
This commit is contained in:
parent
b55baaa1dc
commit
33f8a0fe2e
|
@ -35,6 +35,13 @@ cdef class KnowledgeBase:
|
|||
|
||||
def add_alias(self, unicode alias, entities, probabilities):
|
||||
"""For a given alias, add its potential entities and prior probabilies to the KB."""
|
||||
|
||||
# Throw an error if the probabilities sum up to more than 1
|
||||
prob_sum = sum(probabilities)
|
||||
if prob_sum > 1:
|
||||
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
|
||||
"but found " + str(prob_sum))
|
||||
|
||||
cdef hash_t alias_hash = self.strings.add(alias)
|
||||
|
||||
# Return if this alias was added before
|
||||
|
|
|
@ -42,6 +42,12 @@ def create_kb():
|
|||
|
||||
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
|
||||
|
||||
alias2 = "johny"
|
||||
print(" adding alias2", alias2)
|
||||
mykb.add_alias(alias=alias2, entities=["Q0", "Q42"], probabilities=[0.3, 1.1])
|
||||
|
||||
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
|
||||
|
||||
print("candidates for", alias)
|
||||
candidates = mykb.get_candidates(alias)
|
||||
print(" ", candidates)
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
# coding: utf-8
|
||||
import pytest
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
|
||||
def test_kb_valid_entities():
|
||||
"""Test the valid construction of a KB with 3 entities and one alias"""
|
||||
mykb = KnowledgeBase()
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity_id="Q1", prob=0.5)
|
||||
mykb.add_entity(entity_id="Q2", prob=0.5)
|
||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||
|
||||
# adding aliases
|
||||
|
@ -16,14 +18,29 @@ def test_kb_valid_entities():
|
|||
|
||||
|
||||
def test_kb_invalid_entities():
|
||||
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||
mykb = KnowledgeBase()
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity_id="Q1", prob=0.5)
|
||||
mykb.add_entity(entity_id="Q2", prob=0.5)
|
||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||
|
||||
# adding aliases - should fail because one of the given IDs is not valid
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
|
||||
|
||||
|
||||
def test_kb_invalid_probabilities():
|
||||
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
||||
mykb = KnowledgeBase()
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||
|
||||
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
||||
|
||||
|
|
Loading…
Reference in New Issue