diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 8bf36e9c2..97b27d7ce 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1464,21 +1464,59 @@ class Sentencizer(object): DOCS: https://spacy.io/api/sentencizer#call """ - start = 0 - seen_period = False - for i, token in enumerate(doc): - is_in_punct_chars = token.text in self.punct_chars - token.is_sent_start = i == 0 - if seen_period and not token.is_punct and not is_in_punct_chars: - doc[start].is_sent_start = True - start = token.i - seen_period = False - elif is_in_punct_chars: - seen_period = True - if start < len(doc): - doc[start].is_sent_start = True + tags = self.predict([doc]) + self.set_annotations([doc], tags) return doc + def pipe(self, stream, batch_size=128, n_threads=-1): + for docs in util.minibatch(stream, size=batch_size): + docs = list(docs) + tag_ids = self.predict(docs) + self.set_annotations(docs, tag_ids) + yield from docs + + def predict(self, docs): + """Apply the pipeline's model to a batch of docs, without + modifying them. + """ + if not any(len(doc) for doc in docs): + # Handle cases where there are no tokens in any docs. + guesses = [[] for doc in docs] + return guesses + guesses = [] + for doc in docs: + start = 0 + seen_period = False + doc_guesses = [False] * len(doc) + doc_guesses[0] = True + for i, token in enumerate(doc): + is_in_punct_chars = token.text in self.punct_chars + if seen_period and not token.is_punct and not is_in_punct_chars: + doc_guesses[start] = True + start = token.i + seen_period = False + elif is_in_punct_chars: + seen_period = True + if start < len(doc): + doc_guesses[start] = True + guesses.append(doc_guesses) + return guesses + + def set_annotations(self, docs, batch_tag_ids, tensors=None): + if isinstance(docs, Doc): + docs = [docs] + cdef Doc doc + cdef int idx = 0 + for i, doc in enumerate(docs): + doc_tag_ids = batch_tag_ids[i] + for j, tag_id in enumerate(doc_tag_ids): + # Don't clobber existing sentence boundaries + if doc.c[j].sent_start == 0: + if tag_id: + doc.c[j].sent_start = 1 + else: + doc.c[j].sent_start = -1 + def to_bytes(self, **kwargs): """Serialize the sentencizer to a bytestring. diff --git a/spacy/tests/pipeline/test_sentencizer.py b/spacy/tests/pipeline/test_sentencizer.py index d91fdd198..359552c5b 100644 --- a/spacy/tests/pipeline/test_sentencizer.py +++ b/spacy/tests/pipeline/test_sentencizer.py @@ -5,6 +5,7 @@ import pytest import spacy from spacy.pipeline import Sentencizer from spacy.tokens import Doc +from spacy.lang.en import English def test_sentencizer(en_vocab): @@ -17,6 +18,17 @@ def test_sentencizer(en_vocab): assert len(list(doc.sents)) == 2 +def test_sentencizer_pipe(): + texts = ["Hello! This is a test.", "Hi! This is a test."] + nlp = English() + nlp.add_pipe(nlp.create_pipe("sentencizer")) + for doc in nlp.pipe(texts): + assert doc.is_sentenced + sent_starts = [t.is_sent_start for t in doc] + assert sent_starts == [True, False, True, False, False, False, False] + assert len(list(doc.sents)) == 2 + + @pytest.mark.parametrize( "words,sent_starts,n_sents", [