From 33f8a0fe2e45a6d6c6b8f04b08b2789847cbe74f Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 19 Mar 2019 21:43:48 +0100 Subject: [PATCH] check and unit test in case prior probs exceed 1 --- spacy/kb.pyx | 7 +++++++ spacy/sandbox_test_sofie/testing_el.py | 6 ++++++ spacy/tests/pipeline/test_el.py | 25 +++++++++++++++++++++---- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/spacy/kb.pyx b/spacy/kb.pyx index f67519260..2b38202f3 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -35,6 +35,13 @@ cdef class KnowledgeBase: def add_alias(self, unicode alias, entities, probabilities): """For a given alias, add its potential entities and prior probabilies to the KB.""" + + # Throw an error if the probabilities sum up to more than 1 + prob_sum = sum(probabilities) + if prob_sum > 1: + raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, " + "but found " + str(prob_sum)) + cdef hash_t alias_hash = self.strings.add(alias) # Return if this alias was added before diff --git a/spacy/sandbox_test_sofie/testing_el.py b/spacy/sandbox_test_sofie/testing_el.py index 734eddd8d..71fecb7e6 100644 --- a/spacy/sandbox_test_sofie/testing_el.py +++ b/spacy/sandbox_test_sofie/testing_el.py @@ -42,6 +42,12 @@ def create_kb(): print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) + alias2 = "johny" + print(" adding alias2", alias2) + mykb.add_alias(alias=alias2, entities=["Q0", "Q42"], probabilities=[0.3, 1.1]) + + print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) + print("candidates for", alias) candidates = mykb.get_candidates(alias) print(" ", candidates) diff --git a/spacy/tests/pipeline/test_el.py b/spacy/tests/pipeline/test_el.py index ed88076ce..f9533ef82 100644 --- a/spacy/tests/pipeline/test_el.py +++ b/spacy/tests/pipeline/test_el.py @@ -1,14 +1,16 @@ +# coding: utf-8 import pytest from spacy.kb import KnowledgeBase def test_kb_valid_entities(): + """Test the valid construction of a KB with 3 entities and one alias""" mykb = KnowledgeBase() # adding entities - mykb.add_entity(entity_id="Q1", prob=0.5) - mykb.add_entity(entity_id="Q2", prob=0.5) + mykb.add_entity(entity_id="Q1", prob=0.9) + mykb.add_entity(entity_id="Q2", prob=0.2) mykb.add_entity(entity_id="Q3", prob=0.5) # adding aliases @@ -16,14 +18,29 @@ def test_kb_valid_entities(): def test_kb_invalid_entities(): + """Test the invalid construction of a KB with an alias linked to a non-existing entity""" mykb = KnowledgeBase() # adding entities - mykb.add_entity(entity_id="Q1", prob=0.5) - mykb.add_entity(entity_id="Q2", prob=0.5) + mykb.add_entity(entity_id="Q1", prob=0.9) + mykb.add_entity(entity_id="Q2", prob=0.2) mykb.add_entity(entity_id="Q3", prob=0.5) # adding aliases - should fail because one of the given IDs is not valid with pytest.raises(ValueError): mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]) + +def test_kb_invalid_probabilities(): + """Test the invalid construction of a KB with wrong prior probabilities""" + mykb = KnowledgeBase() + + # adding entities + mykb.add_entity(entity_id="Q1", prob=0.9) + mykb.add_entity(entity_id="Q2", prob=0.2) + mykb.add_entity(entity_id="Q3", prob=0.5) + + # adding aliases - should fail because the sum of the probabilities exceeds 1 + with pytest.raises(ValueError): + mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4]) +