From a9074e0886dd9679fd34a9d82a9618d794af32bf Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 19 Mar 2019 21:55:10 +0100 Subject: [PATCH] check the length of entities and probabilities vector + unit test --- spacy/kb.pyx | 12 ++++++++---- spacy/tests/pipeline/test_el.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/spacy/kb.pyx b/spacy/kb.pyx index bc7cddf11..ba694ce61 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -36,11 +36,18 @@ cdef class KnowledgeBase: def add_alias(self, unicode alias, entities, probabilities): """For a given alias, add its potential entities and prior probabilies to the KB.""" + # Throw an error if the length of entities and probabilities are not the same + if not len(entities) == len(probabilities): + raise ValueError("The vectors for entities and probabilities for alias '" + alias + + "' should have equal length, but found " + + str(len(entities)) + " and " + str(len(probabilities)) + "respectively.") + + # Throw an error if the probabilities sum up to more than 1 prob_sum = sum(probabilities) if prob_sum > 1: raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, " - "but found " + str(prob_sum)) + + "but found " + str(prob_sum)) cdef hash_t alias_hash = self.strings.add(alias) @@ -63,9 +70,6 @@ cdef class KnowledgeBase: entry_indices.push_back(int(entry_index)) probs.push_back(float(prob)) - # TODO: check sum(probabilities) <= 1 - # TODO: check len(entities) == len(probabilities) - self.c_add_aliases(alias_key=alias_hash, entry_indices=entry_indices, probs=probs) diff --git a/spacy/tests/pipeline/test_el.py b/spacy/tests/pipeline/test_el.py index cd71bcb48..068a228d8 100644 --- a/spacy/tests/pipeline/test_el.py +++ b/spacy/tests/pipeline/test_el.py @@ -49,3 +49,17 @@ def test_kb_invalid_probabilities(): with pytest.raises(ValueError): mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4]) + +def test_kb_invalid_combination(): + """Test the invalid construction of a KB with non-matching entity and probability lists""" + mykb = KnowledgeBase() + + # adding entities + mykb.add_entity(entity_id="Q1", prob=0.9) + mykb.add_entity(entity_id="Q2", prob=0.2) + mykb.add_entity(entity_id="Q3", prob=0.5) + + # adding aliases - should fail because the entities and probabilities vectors are not of equal length + with pytest.raises(ValueError): + mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]) +