Add util.filter_spans helper (#3686)

Ines Montani 2019-05-08 02:33:40 +02:00 committed by Matthew Honnibal
parent dd1e6b0bc6
commit 505c9e0e19
3 changed files with 62 additions and 0 deletions


@@ -6,6 +6,7 @@ from spacy.attrs import ORTH, LENGTH
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.errors import ModelsWarning
from spacy.util import filter_spans
from ..util import get_doc
@@ -219,3 +220,21 @@ def test_span_ents_property(doc):
    assert sentences[2].ents[0].label_ == "PRODUCT"
    assert sentences[2].ents[0].start == 11
    assert sentences[2].ents[0].end == 14


def test_filter_spans(doc):
    # Test filtering duplicates
    spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
    filtered = filter_spans(spans)
    assert len(filtered) == 3
    assert filtered[0].start == 1 and filtered[0].end == 4
    assert filtered[1].start == 6 and filtered[1].end == 8
    assert filtered[2].start == 10 and filtered[2].end == 14
    # Test filtering overlaps with longest preference
    spans = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
    filtered = filter_spans(spans)
    assert len(filtered) == 2
    assert len(filtered[0]) == 3
    assert len(filtered[1]) == 5
    assert filtered[0].start == 1 and filtered[0].end == 4
    assert filtered[1].start == 5 and filtered[1].end == 10
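
The `doc` fixture these tests use is defined elsewhere in the test suite and is not part of this hunk. As a rough standalone stand-in (illustrative only, not the actual fixture), any `Doc` with at least 14 tokens reproduces the first check:

```python
import spacy
from spacy.util import filter_spans

# Hypothetical stand-in for the pytest `doc` fixture: any Doc with >= 14 tokens.
nlp = spacy.blank("en")
doc = nlp("one two three four five six seven eight nine ten eleven twelve thirteen fourteen")

spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
assert len(filter_spans(spans)) == 3  # the duplicate doc[1:4] is dropped
```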


@@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
        raise StopIteration


def filter_spans(spans):
    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
    creating named entities (where one token can only be part of one entity) or
    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
    longest span is preferred over shorter spans.

    spans (iterable): The spans to filter.
    RETURNS (list): The filtered spans.
    """
    get_sort_key = lambda span: (span.end - span.start, span.start)
    sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
    result = []
    seen_tokens = set()
    for span in sorted_spans:
        # Check span.end - 1 because Span.end is exclusive: the last token in
        # the span sits at index end - 1.
        if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
            result.append(span)
            seen_tokens.update(range(span.start, span.end))
    result = sorted(result, key=lambda span: span.start)
    return result


def to_bytes(getters, exclude):
    serialized = OrderedDict()
    for key, getter in getters.items():
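
As a usage note beyond the diff itself: the docstring points at `Retokenizer.merge` as the main consumer of this helper. A minimal sketch of that pattern (the sentence and span indices are made up for illustration):

```python
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("The quick brown fox jumps over the lazy dog")

# doc[0:4] and doc[1:4] overlap; filter_spans keeps the longer doc[0:4].
spans = [doc[0:4], doc[1:4], doc[6:9]]

with doc.retokenize() as retokenizer:
    for span in filter_spans(spans):
        retokenizer.merge(span)

print([token.text for token in doc])
# ['The quick brown fox', 'jumps', 'over', 'the lazy dog']
```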


@@ -654,6 +654,27 @@ for batching. Larger `bufsize` means less bias.
| `bufsize`  | int      | Items to hold back.    |
| **YIELDS** | iterable | The shuffled iterator. |

### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}

Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
overlaps. Useful for creating named entities (where one token can only be part
of one entity) or when merging spans with
[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
(first) longest span is preferred over shorter spans.

> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> spans = [doc[0:2], doc[0:2], doc[0:4]]
> filtered = filter_spans(spans)
> ```

| Name        | Type     | Description          |
| ----------- | -------- | -------------------- |
| `spans`     | iterable | The spans to filter. |
| **RETURNS** | list     | The filtered spans.  |
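
The example above only exercises duplicate removal. A small additional sketch of the overlap case described in the prose (sentence and indices chosen for illustration, not taken from the docs):

```python
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("This is a sentence about natural language processing.")

spans = [doc[3:5], doc[3:7], doc[6:8]]  # doc[3:7] overlaps both shorter spans
filtered = filter_spans(spans)          # only the longest span, doc[3:7], survives
```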
## Compatibility functions {#compat source="spacy/compat.py"}
All Python code is written in an **intersection of Python 2 and Python 3**. This