Add util.filter_spans helper (#3686)

This commit is contained in:
Ines Montani 2019-05-08 02:33:40 +02:00 committed by Matthew Honnibal
parent dd1e6b0bc6
commit 505c9e0e19
3 changed files with 62 additions and 0 deletions

View File

@ -6,6 +6,7 @@ from spacy.attrs import ORTH, LENGTH
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.errors import ModelsWarning
from spacy.util import filter_spans
from ..util import get_doc
@ -219,3 +220,21 @@ def test_span_ents_property(doc):
assert sentences[2].ents[0].label_ == "PRODUCT"
assert sentences[2].ents[0].start == 11
assert sentences[2].ents[0].end == 14
def test_filter_spans(doc):
# Test filtering duplicates
spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
filtered = filter_spans(spans)
assert len(filtered) == 3
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 6 and filtered[1].end == 8
assert filtered[2].start == 10 and filtered[2].end == 14
# Test filtering overlaps with longest preference
spans = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
filtered = filter_spans(spans)
assert len(filtered) == 2
assert len(filtered[0]) == 3
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10

View File

@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
raise StopIteration
def filter_spans(spans):
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
creating named entities (where one token can only be part of one entity) or
when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
longest span is preferred over shorter spans.
spans (iterable): The spans to filter.
RETURNS (list): The filtered spans.
"""
get_sort_key = lambda span: (span.end - span.start, span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
result.append(span)
seen_tokens.update(range(span.start, span.end))
result = sorted(result, key=lambda span: span.start)
return result
def to_bytes(getters, exclude):
serialized = OrderedDict()
for key, getter in getters.items():

View File

@ -654,6 +654,27 @@ for batching. Larger `buffsize` means less bias.
| `buffsize` | int | Items to hold back. |
| **YIELDS** | iterable | The shuffled iterator. |
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
overlaps. Useful for creating named entities (where one token can only be part
of one entity) or when merging spans with
[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
(first) longest span is preferred over shorter spans.
> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> spans = [doc[0:2], doc[0:2], doc[0:4]]
> filtered = filter_spans(spans)
> ```
| Name | Type | Description |
| ----------- | -------- | -------------------- |
| `spans` | iterable | The spans to filter. |
| **RETURNS** | list | The filtered spans. |
## Compatibility functions {#compat source="spacy/compaty.py"}
All Python code is written in an **intersection of Python 2 and Python 3**. This