allow small rounding errors

This commit is contained in:
svlandeg 2019-05-01 23:05:40 +02:00
parent 3629a52ede
commit 1ae41daaa9
2 changed files with 8 additions and 5 deletions

View File

@ -61,13 +61,13 @@ def create_kb(vocab, max_entities_per_alias, min_occ, to_print=False):
entity_frequencies = _get_entity_frequencies(entities=title_list)
print()
print("3. _add_entities", datetime.datetime.now())
print("3. adding", len(entity_list), "entities", datetime.datetime.now())
print()
kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None, feature_list=None)
# _add_entities(kb, entities=entity_list, probs=entity_frequencies, to_print=to_print)
print()
print("4. _add_aliases", datetime.datetime.now())
print("4. adding aliases", datetime.datetime.now())
print()
_add_aliases(kb, title_to_id=title_to_id, max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,)
@ -171,7 +171,10 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, to_print=Fals
prior_probs.append(p_entity_givenalias)
if selected_entities:
try:
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
except ValueError as e:
print(e)
total_count = 0
counts = list()
entities = list()

View File

@ -179,9 +179,9 @@ cdef class KnowledgeBase:
entities_length=len(entities),
probabilities_length=len(probabilities)))
# Throw an error if the probabilities sum up to more than 1
# Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors)
prob_sum = sum(probabilities)
if prob_sum > 1:
if prob_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
cdef hash_t alias_hash = self.vocab.strings.add(alias)