replace asserts with custom error messages

This commit is contained in:
svlandeg 2019-07-23 11:52:48 +02:00
parent cd6c263fe4
commit 400ff342cf
2 changed files with 18 additions and 6 deletions

View File

@ -409,6 +409,10 @@ class Errors(object):
E144 = ("Could not find parameter `{param}` when building the entity linker model.")
E145 = ("Error reading `{param}` from input file.")
E146 = ("Could not access `{path}`.")
E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
"This is likely a bug in spaCy, so feel free to open an issue.")
E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
"is assigned to a KB identifier.")
@add_codes

View File

@ -1160,7 +1160,9 @@ class EntityLinker(Pipe):
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset["{}_{}".format(start, end)]
assert gold_ent is not None
if gold_ent is None:
raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1
@ -1176,7 +1178,8 @@ class EntityLinker(Pipe):
priors.append([0])
if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
@ -1204,7 +1207,8 @@ class EntityLinker(Pipe):
cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32")
assert len(scores) == len(cats)
if len(scores) != len(cats):
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
d_scores = (scores - cats)
loss = (d_scores ** 2).sum()
@ -1267,7 +1271,9 @@ class EntityLinker(Pipe):
if self.cfg.get("context_weight", 1) > 0:
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
assert len(entity_encodings) == len(prior_probs)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
mention_encodings = [list(context_encoding) + list(entity_encodings[i])
+ list(prior_probs[i]) + type_vector
for i in range(len(entity_encodings))]
@ -1279,13 +1285,15 @@ class EntityLinker(Pipe):
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
assert len(final_tensors) == len(final_kb_ids) == entity_count
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
return final_kb_ids, final_tensors
def set_annotations(self, docs, kb_ids, tensors=None):
count_ents = len([ent for doc in docs for ent in doc.ents])
assert count_ents == len(kb_ids)
if count_ents != len(kb_ids):
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
i=0
for doc in docs: