From f7d1736241a3cc583297633d7b14878b1fefdb6e Mon Sep 17 00:00:00 2001
From: Ines Montani
Date: Mon, 30 Sep 2019 12:43:48 +0200
Subject: [PATCH] Skip duplicate spans in Doc.retokenize (#4339)

---
 spacy/tests/doc/test_retokenize_merge.py | 11 +++++++++++
 spacy/tokens/_retokenize.pyx             |  5 +++++
 2 files changed, 16 insertions(+)

diff --git a/spacy/tests/doc/test_retokenize_merge.py b/spacy/tests/doc/test_retokenize_merge.py
index 28f00aa18..5bdf78f39 100644
--- a/spacy/tests/doc/test_retokenize_merge.py
+++ b/spacy/tests/doc/test_retokenize_merge.py
@@ -414,3 +414,14 @@ def test_doc_retokenizer_merge_lex_attrs(en_vocab):
     assert doc[1].is_stop
     assert not doc[0].is_stop
     assert not doc[1].like_num
+
+
+def test_retokenize_skip_duplicates(en_vocab):
+    """Test that the retokenizer automatically skips duplicate spans instead
+    of complaining about overlaps. See #3687."""
+    doc = Doc(en_vocab, words=["hello", "world", "!"])
+    with doc.retokenize() as retokenizer:
+        retokenizer.merge(doc[0:2])
+        retokenizer.merge(doc[0:2])
+    assert len(doc) == 2
+    assert doc[0].text == "hello world"
diff --git a/spacy/tokens/_retokenize.pyx b/spacy/tokens/_retokenize.pyx
index 5b0747fa0..f8b13dd78 100644
--- a/spacy/tokens/_retokenize.pyx
+++ b/spacy/tokens/_retokenize.pyx
@@ -35,12 +35,14 @@ cdef class Retokenizer:
     cdef list merges
     cdef list splits
     cdef set tokens_to_merge
+    cdef set _spans_to_merge
 
     def __init__(self, doc):
         self.doc = doc
         self.merges = []
         self.splits = []
         self.tokens_to_merge = set()
+        self._spans_to_merge = set()  # (start, end) of queued spans, to skip duplicates
 
     def merge(self, Span span, attrs=SimpleFrozenDict()):
         """Mark a span for merging. The attrs will be applied to the resulting
@@ -51,10 +53,13 @@ cdef class Retokenizer:
 
         DOCS: https://spacy.io/api/doc#retokenizer.merge
         """
+        if (span.start, span.end) in self._spans_to_merge:
+            return  # identical span already queued; skip instead of raising E102
         for token in span:
             if token.i in self.tokens_to_merge:
                 raise ValueError(Errors.E102.format(token=repr(token)))
             self.tokens_to_merge.add(token.i)
+        self._spans_to_merge.add((span.start, span.end))
         if "_" in attrs:  # Extension attributes
             extensions = attrs["_"]
             _validate_extensions(extensions)