spaCy/examples/information_extraction/entity_relations.py

83 lines
2.7 KiB
Python
Raw Normal View History

2017-10-26 16:46:11 +00:00
#!/usr/bin/env python
# coding: utf8
2017-10-31 23:43:22 +00:00
"""A simple example of extracting relations between phrases and entities using
spaCy's named entity recognizer and the dependency parse. Here, we extract
money and currency values (entities labelled as MONEY) and then check the
dependency tree to find the noun phrase they are referring to, for example:
$9.4 million --> Net income.

Compatible with: spaCy v2.0.0+
Last tested with: v2.2.1
"""
from __future__ import unicode_literals, print_function
import plac
import spacy
# Sample sentences containing money amounts; main() runs each one through
# the pipeline and prints the relations extracted from it.
TEXTS = [
    "Net income was $9.4 million compared to the prior year of $2.7 million.",
    "Revenue exceeded twelve billion dollars, with a loss of $1b.",
]
@plac.annotations(
    model=("Model to load (needs parser and NER)", "positional", None, str)
)
def main(model="en_core_web_sm"):
    """Load a spaCy pipeline and print the currency relations found in TEXTS.

    model: name or path handed to ``spacy.load``. The pipeline must provide
    both a dependency parser and a named entity recognizer, because
    ``extract_currency_relations`` relies on entity types and dependency labels.

    For every relation, one tab-separated row is printed: the anchor phrase
    (left-aligned to 10 characters), the entity label of the amount, and the
    amount's text.
    """
    nlp = spacy.load(model)
    print("Loaded model '%s'" % model)
    print("Processing %d texts" % len(TEXTS))

    row_template = "{:<10}\t{}\t{}"
    for text in TEXTS:
        for phrase, amount in extract_currency_relations(nlp(text)):
            print(row_template.format(phrase.text, amount.ent_type_, amount.text))
2017-10-26 16:46:11 +00:00
2019-05-06 13:13:10 +00:00
def filter_spans(spans):
    """Return a non-overlapping subset of ``spans``, ordered by start index.

    When spans overlap, the longer one wins; among spans of equal length the
    one that starts earlier wins. For spaCy 2.1.4+ the same helper is
    available as ``spacy.util.filter_spans()``.
    """

    def _priority(span):
        # Longest first; for equal lengths, smallest start first (via -start,
        # because the sort below is descending).
        return (span.end - span.start, -span.start)

    claimed = set()
    kept = []
    for candidate in sorted(spans, key=_priority, reverse=True):
        # span.end is exclusive, so the last token index is end - 1. Since
        # candidates arrive longest-first, a candidate that overlaps an
        # already-kept span must have its first or last token claimed.
        if candidate.start in claimed or candidate.end - 1 in claimed:
            continue
        kept.append(candidate)
        claimed.update(range(candidate.start, candidate.end))

    kept.sort(key=lambda span: span.start)
    return kept
2017-10-26 16:46:11 +00:00
def extract_currency_relations(doc):
    """Pair each MONEY token in ``doc`` with the phrase it refers to.

    Side effect: ``doc`` is retokenized in place. Every named entity and noun
    chunk (after overlap removal with ``filter_spans``) is merged into a
    single token, so that a multi-word phrase such as "Net income" or
    "$9.4 million" becomes one node of the dependency tree.

    Two patterns are matched for each token whose entity type is MONEY:

    * dependency ``attr`` or ``dobj``: paired with the first ``nsubj`` among
      the left children of the token's head, if there is one;
    * dependency ``pobj`` whose head is a ``prep``: paired with the head of
      that preposition.

    Returns a list of ``(anchor_token, money_token)`` tuples in document order.
    """
    merge_targets = filter_spans(list(doc.ents) + list(doc.noun_chunks))
    with doc.retokenize() as retokenizer:
        for phrase in merge_targets:
            retokenizer.merge(phrase)

    pairs = []
    for token in doc:
        if token.ent_type_ != "MONEY":
            continue
        if token.dep_ in ("attr", "dobj"):
            # Amount hangs off a verb: attach it to that verb's subject.
            subjects = [child for child in token.head.lefts if child.dep_ == "nsubj"]
            if subjects:
                pairs.append((subjects[0], token))
        elif token.dep_ == "pobj" and token.head.dep_ == "prep":
            # Amount is the object of a preposition: attach it to what the
            # preposition modifies.
            pairs.append((token.head.head, token))
    return pairs
2018-12-02 03:26:26 +00:00
if __name__ == "__main__":
    # Hand main() to plac, which builds the command-line interface from the
    # @plac.annotations declared on it and then calls it.
    plac.call(main)
# Expected output:
# Net income MONEY $9.4 million
# the prior year MONEY $2.7 million
# Revenue MONEY twelve billion dollars
# a loss MONEY 1b