mirror of https://github.com/explosion/spaCy.git
check and unit test in case prior probs exceed 1
This commit is contained in:
parent
b55baaa1dc
commit
33f8a0fe2e
|
@ -35,6 +35,13 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
def add_alias(self, unicode alias, entities, probabilities):
|
def add_alias(self, unicode alias, entities, probabilities):
|
||||||
"""For a given alias, add its potential entities and prior probabilies to the KB."""
|
"""For a given alias, add its potential entities and prior probabilies to the KB."""
|
||||||
|
|
||||||
|
# Throw an error if the probabilities sum up to more than 1
|
||||||
|
prob_sum = sum(probabilities)
|
||||||
|
if prob_sum > 1:
|
||||||
|
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
|
||||||
|
"but found " + str(prob_sum))
|
||||||
|
|
||||||
cdef hash_t alias_hash = self.strings.add(alias)
|
cdef hash_t alias_hash = self.strings.add(alias)
|
||||||
|
|
||||||
# Return if this alias was added before
|
# Return if this alias was added before
|
||||||
|
|
|
@ -42,6 +42,12 @@ def create_kb():
|
||||||
|
|
||||||
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
|
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
|
||||||
|
|
||||||
|
alias2 = "johny"
|
||||||
|
print(" adding alias2", alias2)
|
||||||
|
mykb.add_alias(alias=alias2, entities=["Q0", "Q42"], probabilities=[0.3, 1.1])
|
||||||
|
|
||||||
|
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
|
||||||
|
|
||||||
print("candidates for", alias)
|
print("candidates for", alias)
|
||||||
candidates = mykb.get_candidates(alias)
|
candidates = mykb.get_candidates(alias)
|
||||||
print(" ", candidates)
|
print(" ", candidates)
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
|
# coding: utf-8
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
def test_kb_valid_entities():
|
def test_kb_valid_entities():
|
||||||
|
"""Test the valid construction of a KB with 3 entities and one alias"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase()
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.5)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
mykb.add_entity(entity_id="Q2", prob=0.5)
|
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||||
mykb.add_entity(entity_id="Q3", prob=0.5)
|
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||||
|
|
||||||
# adding aliases
|
# adding aliases
|
||||||
|
@ -16,14 +18,29 @@ def test_kb_valid_entities():
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_entities():
|
def test_kb_invalid_entities():
|
||||||
|
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase()
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.5)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
mykb.add_entity(entity_id="Q2", prob=0.5)
|
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||||
mykb.add_entity(entity_id="Q3", prob=0.5)
|
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||||
|
|
||||||
# adding aliases - should fail because one of the given IDs is not valid
|
# adding aliases - should fail because one of the given IDs is not valid
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
|
mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_invalid_probabilities():
|
||||||
|
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
||||||
|
mykb = KnowledgeBase()
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
mykb.add_entity(entity_id="Q2", prob=0.2)
|
||||||
|
mykb.add_entity(entity_id="Q3", prob=0.5)
|
||||||
|
|
||||||
|
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue