diff --git a/examples/information_extraction/entity_relations.py b/examples/information_extraction/entity_relations.py index ffc8164e1..593324134 100644 --- a/examples/information_extraction/entity_relations.py +++ b/examples/information_extraction/entity_relations.py @@ -36,11 +36,27 @@ def main(model="en_core_web_sm"): print("{:<10}\t{}\t{}".format(r1.text, r2.ent_type_, r2.text)) -def extract_currency_relations(doc): - # merge entities and noun chunks into one token - spans = list(doc.ents) + list(doc.noun_chunks) +def filter_spans(spans, prefer_longest=True): + # Filter a sequence of spans so they don't contain overlaps + get_sort_key = lambda span: (span.end - span.start, span.start) + sorted_spans = sorted(spans, key=get_sort_key, reverse=prefer_longest) + result = [] + seen_tokens = set() for span in spans: - span.merge() + if span.start not in seen_tokens and span.end - 1 not in seen_tokens: + result.append(span) + seen_tokens.update(range(span.start, span.end)) + return result + + +def extract_currency_relations(doc): + # Merge entities and noun chunks into one token + seen_tokens = set() + spans = list(doc.ents) + list(doc.noun_chunks) + spans = filter_spans(spans) + with doc.retokenize() as retokenizer: + for span in spans: + retokenizer.merge(span) relations = [] for money in filter(lambda w: w.ent_type_ == "MONEY", doc):