From 2f0bb7792081f9f0ab8caddaddf305244d7775d5 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Wed, 22 Sep 2021 09:41:05 +0200 Subject: [PATCH] Accept Doc input in pipelines (#9069) * Accept Doc input in pipelines Allow `Doc` input to `Language.__call__` and `Language.pipe`, which skips `Language.make_doc` and passes the doc directly to the pipeline. * ensure_doc helper function * avoid running multiple processes on GPU * Update spacy/tests/test_language.py Co-authored-by: svlandeg --- spacy/errors.py | 1 + spacy/language.py | 34 +++++++++++++++++++++++----------- spacy/tests/test_language.py | 26 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 9264ca6d1..f1c068793 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -521,6 +521,7 @@ class Errors: E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.") # New errors added in v3.x + E866 = ("Expected a string or 'Doc' as input, but got: {type}.") E867 = ("The 'textcat' component requires at least two labels because it " "uses mutually exclusive classes where exactly one label is True " "for each doc. For binary classification tasks, you can use two " diff --git a/spacy/language.py b/spacy/language.py index a8cad1259..540937e66 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -968,7 +968,7 @@ class Language: def __call__( self, - text: str, + text: Union[str, Doc], *, disable: Iterable[str] = SimpleFrozenList(), component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, @@ -977,7 +977,9 @@ class Language: and can contain arbitrary whitespace. Alignment into the original string is preserved. - text (str): The text to be processed. + text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`, + the doc will be passed directly to the pipeline, skipping + `Language.make_doc`. disable (list): Names of the pipeline components to disable. component_cfg (Dict[str, dict]): An optional dictionary with extra keyword arguments for specific components. @@ -985,7 +987,7 @@ class Language: DOCS: https://spacy.io/api/language#call """ - doc = self.make_doc(text) + doc = self._ensure_doc(text) if component_cfg is None: component_cfg = {} for name, proc in self.pipeline: @@ -1069,6 +1071,14 @@ class Language: ) return self.tokenizer(text) + def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc: + """Create a Doc if need be, or raise an error if the input is not a Doc or a string.""" + if isinstance(doc_like, Doc): + return doc_like + if isinstance(doc_like, str): + return self.make_doc(doc_like) + raise ValueError(Errors.E866.format(type=type(doc_like))) + def update( self, examples: Iterable[Example], @@ -1437,7 +1447,7 @@ class Language: @overload def pipe( self, - texts: Iterable[Tuple[str, _AnyContext]], + texts: Iterable[Tuple[Union[str, Doc], _AnyContext]], *, as_tuples: bool = ..., batch_size: Optional[int] = ..., @@ -1449,7 +1459,7 @@ class Language: def pipe( # noqa: F811 self, - texts: Iterable[str], + texts: Iterable[Union[str, Doc]], *, as_tuples: bool = False, batch_size: Optional[int] = None, @@ -1459,7 +1469,8 @@ class Language: ) -> Iterator[Doc]: """Process texts as a stream, and yield `Doc` objects in order. - texts (Iterable[str]): A sequence of texts to process. + texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to + process. as_tuples (bool): If set to True, inputs should be a sequence of (text, context) tuples. Output will then be a sequence of (doc, context) tuples. Defaults to False. @@ -1515,7 +1526,7 @@ class Language: docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size) else: # if n_process == 1, no processes are forked. - docs = (self.make_doc(text) for text in texts) + docs = (self._ensure_doc(text) for text in texts) for pipe in pipes: docs = pipe(docs) for doc in docs: @@ -1549,7 +1560,7 @@ class Language: procs = [ mp.Process( target=_apply_pipes, - args=(self.make_doc, pipes, rch, sch, Underscore.get_state()), + args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()), ) for rch, sch in zip(texts_q, bytedocs_send_ch) ] @@ -2084,7 +2095,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]: def _apply_pipes( - make_doc: Callable[[str], Doc], + ensure_doc: Callable[[Union[str, Doc]], Doc], pipes: Iterable[Callable[[Doc], Doc]], receiver, sender, @@ -2092,7 +2103,8 @@ def _apply_pipes( ) -> None: """Worker for Language.pipe - make_doc (Callable[[str,] Doc]): Function to create Doc from text. + ensure_doc (Callable[[Union[str, Doc]], Doc]): Function to create Doc from text + or raise an error if the input is neither a Doc nor a string. pipes (Iterable[Callable[[Doc], Doc]]): The components to apply. receiver (multiprocessing.Connection): Pipe to receive text. Usually created by `multiprocessing.Pipe()` @@ -2105,7 +2117,7 @@ def _apply_pipes( while True: try: texts = receiver.get() - docs = (make_doc(text) for text in texts) + docs = (ensure_doc(text) for text in texts) for pipe in pipes: docs = pipe(docs) # Connection does not accept unpickable objects, so send list. diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index c911b8d81..e3c25fece 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -528,3 +528,29 @@ def test_language_source_and_vectors(nlp2): assert long_string in nlp2.vocab.strings # vectors should remain unmodified assert nlp.vocab.vectors.to_bytes() == vectors_bytes + + +@pytest.mark.parametrize("n_process", [1, 2]) +def test_pass_doc_to_pipeline(nlp, n_process): + texts = ["cats", "dogs", "guinea pigs"] + docs = [nlp.make_doc(text) for text in texts] + assert not any(len(doc.cats) for doc in docs) + doc = nlp(docs[0]) + assert doc.text == texts[0] + assert len(doc.cats) > 0 + if isinstance(get_current_ops(), NumpyOps) or n_process < 2: + docs = nlp.pipe(docs, n_process=n_process) + assert [doc.text for doc in docs] == texts + assert all(len(doc.cats) for doc in docs) + + +def test_invalid_arg_to_pipeline(nlp): + str_list = ["This is a text.", "This is another."] + with pytest.raises(ValueError): + nlp(str_list) # type: ignore + assert len(list(nlp.pipe(str_list))) == 2 + int_list = [1, 2, 3] + with pytest.raises(ValueError): + list(nlp.pipe(int_list)) # type: ignore + with pytest.raises(ValueError): + nlp(int_list) # type: ignore