mirror of https://github.com/explosion/spaCy.git
Add util.filter_spans helper (#3686)
This commit is contained in:
parent
dd1e6b0bc6
commit
505c9e0e19
|
@ -6,6 +6,7 @@ from spacy.attrs import ORTH, LENGTH
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.errors import ModelsWarning
|
from spacy.errors import ModelsWarning
|
||||||
|
from spacy.util import filter_spans
|
||||||
|
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
@ -219,3 +220,21 @@ def test_span_ents_property(doc):
|
||||||
assert sentences[2].ents[0].label_ == "PRODUCT"
|
assert sentences[2].ents[0].label_ == "PRODUCT"
|
||||||
assert sentences[2].ents[0].start == 11
|
assert sentences[2].ents[0].start == 11
|
||||||
assert sentences[2].ents[0].end == 14
|
assert sentences[2].ents[0].end == 14
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_spans(doc):
    """filter_spans removes duplicate spans and, on overlap, keeps the longest."""
    # Duplicate spans collapse to a single entry each; output is start-sorted.
    candidates = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
    result = filter_spans(candidates)
    assert [(span.start, span.end) for span in result] == [(1, 4), (6, 8), (10, 14)]
    # Overlapping spans resolve in favour of the longest one.
    candidates = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
    result = filter_spans(candidates)
    assert [len(span) for span in result] == [3, 5]
    assert [(span.start, span.end) for span in result] == [(1, 4), (5, 10)]
|
@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
|
|
||||||
|
def filter_spans(spans):
    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
    creating named entities (where one token can only be part of one entity) or
    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
    longest span is preferred over shorter spans.

    spans (iterable): The spans to filter.
    RETURNS (list): The filtered spans, sorted by start position.
    """
    def sort_key(span):
        # Longest span first; ties broken by start position (reverse sort).
        return (span.end - span.start, span.start)

    kept = []
    covered = set()  # token indices already claimed by an accepted span
    for span in sorted(spans, key=sort_key, reverse=True):
        # Check end - 1 because span boundaries are inclusive on the token
        # level; a shorter span overlapping an accepted one must share one
        # of its boundary tokens, since spans are processed longest-first.
        if span.start in covered or span.end - 1 in covered:
            continue
        kept.append(span)
        covered.update(range(span.start, span.end))
    kept.sort(key=lambda span: span.start)
    return kept
|
|
||||||
|
|
||||||
def to_bytes(getters, exclude):
|
def to_bytes(getters, exclude):
|
||||||
serialized = OrderedDict()
|
serialized = OrderedDict()
|
||||||
for key, getter in getters.items():
|
for key, getter in getters.items():
|
||||||
|
|
|
@ -654,6 +654,27 @@ for batching. Larger `buffsize` means less bias.
|
||||||
| `buffsize` | int | Items to hold back. |
|
| `buffsize` | int | Items to hold back. |
|
||||||
| **YIELDS** | iterable | The shuffled iterator. |
|
| **YIELDS** | iterable | The shuffled iterator. |
|
||||||
|
|
||||||
|
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
|
||||||
|
|
||||||
|
Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
|
||||||
|
overlaps. Useful for creating named entities (where one token can only be part
|
||||||
|
of one entity) or when merging spans with
|
||||||
|
[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
|
||||||
|
(first) longest span is preferred over shorter spans.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> doc = nlp("This is a sentence.")
|
||||||
|
> spans = [doc[0:2], doc[0:2], doc[0:4]]
|
||||||
|
> filtered = filter_spans(spans)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | -------- | -------------------- |
|
||||||
|
| `spans` | iterable | The spans to filter. |
|
||||||
|
| **RETURNS** | list | The filtered spans. |
|
||||||
|
|
||||||
## Compatibility functions {#compat source="spacy/compat.py"}
|
## Compatibility functions {#compat source="spacy/compat.py"}
|
||||||
|
|
||||||
All Python code is written in an **intersection of Python 2 and Python 3**. This
|
All Python code is written in an **intersection of Python 2 and Python 3**. This
|
||||||
|
|
Loading…
Reference in New Issue