mirror of https://github.com/explosion/spaCy.git
Add util.filter_spans helper (#3686)
This commit is contained in:
parent
dd1e6b0bc6
commit
505c9e0e19
|
@ -6,6 +6,7 @@ from spacy.attrs import ORTH, LENGTH
|
|||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.errors import ModelsWarning
|
||||
from spacy.util import filter_spans
|
||||
|
||||
from ..util import get_doc
|
||||
|
||||
|
@ -219,3 +220,21 @@ def test_span_ents_property(doc):
|
|||
assert sentences[2].ents[0].label_ == "PRODUCT"
|
||||
assert sentences[2].ents[0].start == 11
|
||||
assert sentences[2].ents[0].end == 14
|
||||
|
||||
|
||||
def test_filter_spans(doc):
|
||||
# Test filtering duplicates
|
||||
spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
|
||||
filtered = filter_spans(spans)
|
||||
assert len(filtered) == 3
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 6 and filtered[1].end == 8
|
||||
assert filtered[2].start == 10 and filtered[2].end == 14
|
||||
# Test filtering overlaps with longest preference
|
||||
spans = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
|
||||
filtered = filter_spans(spans)
|
||||
assert len(filtered) == 2
|
||||
assert len(filtered[0]) == 3
|
||||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
|
|
|
@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
|
|||
raise StopIteration
|
||||
|
||||
|
||||
def filter_spans(spans):
|
||||
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
|
||||
creating named entities (where one token can only be part of one entity) or
|
||||
when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
|
||||
longest span is preferred over shorter spans.
|
||||
|
||||
spans (iterable): The spans to filter.
|
||||
RETURNS (list): The filtered spans.
|
||||
"""
|
||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
result = []
|
||||
seen_tokens = set()
|
||||
for span in sorted_spans:
|
||||
# Check for end - 1 here because boundaries are inclusive
|
||||
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
||||
result.append(span)
|
||||
seen_tokens.update(range(span.start, span.end))
|
||||
result = sorted(result, key=lambda span: span.start)
|
||||
return result
|
||||
|
||||
|
||||
def to_bytes(getters, exclude):
|
||||
serialized = OrderedDict()
|
||||
for key, getter in getters.items():
|
||||
|
|
|
@ -654,6 +654,27 @@ for batching. Larger `buffsize` means less bias.
|
|||
| `buffsize` | int | Items to hold back. |
|
||||
| **YIELDS** | iterable | The shuffled iterator. |
|
||||
|
||||
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
|
||||
|
||||
Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
|
||||
overlaps. Useful for creating named entities (where one token can only be part
|
||||
of one entity) or when merging spans with
|
||||
[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
|
||||
(first) longest span is preferred over shorter spans.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("This is a sentence.")
|
||||
> spans = [doc[0:2], doc[0:2], doc[0:4]]
|
||||
> filtered = filter_spans(spans)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | -------- | -------------------- |
|
||||
| `spans` | iterable | The spans to filter. |
|
||||
| **RETURNS** | list | The filtered spans. |
|
||||
|
||||
## Compatibility functions {#compat source="spacy/compaty.py"}
|
||||
|
||||
All Python code is written in an **intersection of Python 2 and Python 3**. This
|
||||
|
|
Loading…
Reference in New Issue