diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py
index 13f7f2771..60b711741 100644
--- a/spacy/tests/doc/test_span.py
+++ b/spacy/tests/doc/test_span.py
@@ -6,6 +6,7 @@ from spacy.attrs import ORTH, LENGTH
 from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab
 from spacy.errors import ModelsWarning
+from spacy.util import filter_spans

 from ..util import get_doc

@@ -219,3 +220,21 @@ def test_span_ents_property(doc):
     assert sentences[2].ents[0].label_ == "PRODUCT"
     assert sentences[2].ents[0].start == 11
     assert sentences[2].ents[0].end == 14
+
+
+def test_filter_spans(doc):
+    # Test filtering duplicates
+    spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
+    filtered = filter_spans(spans)
+    assert len(filtered) == 3
+    assert filtered[0].start == 1 and filtered[0].end == 4
+    assert filtered[1].start == 6 and filtered[1].end == 8
+    assert filtered[2].start == 10 and filtered[2].end == 14
+    # Test filtering overlaps with longest preference
+    spans = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
+    filtered = filter_spans(spans)
+    assert len(filtered) == 2
+    assert len(filtered[0]) == 3
+    assert len(filtered[1]) == 5
+    assert filtered[0].start == 1 and filtered[0].end == 4
+    assert filtered[1].start == 5 and filtered[1].end == 10
diff --git a/spacy/util.py b/spacy/util.py
index 59498ca77..475d556d0 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
         raise StopIteration


+def filter_spans(spans):
+    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
+    creating named entities (where one token can only be part of one entity) or
+    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
+    longest span is preferred over shorter spans.
+
+    spans (iterable): The spans to filter.
+    RETURNS (list): The filtered spans.
+    """
+    get_sort_key = lambda span: (span.end - span.start, span.start)
+    sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
+    result = []
+    seen_tokens = set()
+    for span in sorted_spans:
+        # Check for end - 1 here because boundaries are inclusive
+        if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
+            result.append(span)
+            seen_tokens.update(range(span.start, span.end))
+    result = sorted(result, key=lambda span: span.start)
+    return result
+
+
 def to_bytes(getters, exclude):
     serialized = OrderedDict()
     for key, getter in getters.items():
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index b9b3cc762..e687cefa8 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -654,6 +654,27 @@ for batching. Larger `buffsize` means less bias.
 | `buffsize` | int      | Items to hold back.    |
 | **YIELDS** | iterable | The shuffled iterator. |

+### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
+
+Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
+overlaps. Useful for creating named entities (where one token can only be part
+of one entity) or when merging spans with
+[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
+(first) longest span is preferred over shorter spans.
+
+> #### Example
+>
+> ```python
+> doc = nlp("This is a sentence.")
+> spans = [doc[0:2], doc[0:2], doc[0:4]]
+> filtered = filter_spans(spans)
+> ```
+
+| Name        | Type     | Description          |
+| ----------- | -------- | -------------------- |
+| `spans`     | iterable | The spans to filter. |
+| **RETURNS** | list     | The filtered spans.  |
+
 ## Compatibility functions {#compat source="spacy/compat.py"}

 All Python code is written in an **intersection of Python 2 and Python 3**. This
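For reviewers who want to try the new helper outside the test suite, here is a minimal usage sketch (not part of the diff itself). It assumes the change is applied (spaCy 2.1.4+, where `spacy.util.filter_spans` first appears) and uses a blank English pipeline so no statistical model needs to be downloaded; the example sentence and span choices are purely illustrative.

```python
# Minimal sketch, not part of this PR: demonstrates filter_spans on
# overlapping and duplicate spans. Assumes spaCy >= 2.1.4.
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")  # tokenizer-only pipeline, no model required
doc = nlp("The quick brown fox jumps over the lazy dog")

# "quick brown fox" (1:4) overlaps "brown fox" (2:4); "lazy dog" (7:9) is duplicated
spans = [doc[1:4], doc[2:4], doc[7:9], doc[7:9]]
filtered = filter_spans(spans)

print([(span.text, span.start, span.end) for span in filtered])
# The longest span wins the overlap and the duplicate collapses:
# [('quick brown fox', 1, 4), ('lazy dog', 7, 9)]
```

The design follows directly from the implementation above: candidates are sorted longest-first before the token-overlap check, which is why the longest span wins ties against shorter ones, and the final sort by `span.start` restores document order in the returned list.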