diff --git a/spacy/language.py b/spacy/language.py index 55c9912cc..aa57989ac 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1537,8 +1537,7 @@ class Language: yield (doc, context) return - # At this point, we know that we're dealing with an iterable of plain texts - texts = cast(Iterable[str], texts) + texts = cast(Iterable[Union[str, Doc]], texts) # Set argument defaults if n_process == -1: @@ -1592,7 +1591,7 @@ class Language: def _multiprocessing_pipe( self, - texts: Iterable[str], + texts: Iterable[Union[str, Doc]], pipes: Iterable[Callable[..., Iterator[Doc]]], n_process: int, batch_size: int, diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 444b1c83e..c5fdc8eb0 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -255,6 +255,38 @@ def test_language_pipe_error_handler_custom(en_vocab, n_process): assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"] +@pytest.mark.parametrize("n_process", [1, 2]) +def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process): + """Test the error handling of nlp.pipe with input as tuples""" + Language.component("my_evil_component", func=evil_component) + ops = get_current_ops() + if isinstance(ops, NumpyOps) or n_process < 2: + nlp = English() + nlp.add_pipe("my_evil_component") + texts = [ + ("TEXT 111", 111), + ("TEXT 222", 222), + ("TEXT 333", 333), + ("TEXT 342", 342), + ("TEXT 666", 666), + ] + with pytest.raises(ValueError): + list(nlp.pipe(texts, as_tuples=True)) + nlp.set_error_handler(warn_error) + logger = logging.getLogger("spacy") + with mock.patch.object(logger, "warning") as mock_warning: + tuples = list(nlp.pipe(texts, as_tuples=True, n_process=n_process)) + # HACK/TODO? the warnings in child processes don't seem to be + # detected by the mock logger + if n_process == 1: + mock_warning.assert_called() + assert mock_warning.call_count == 2 + assert len(tuples) + mock_warning.call_count == len(texts) + assert (tuples[0][0].text, tuples[0][1]) == ("TEXT 111", 111) + assert (tuples[1][0].text, tuples[1][1]) == ("TEXT 333", 333) + assert (tuples[2][0].text, tuples[2][1]) == ("TEXT 666", 666) + + @pytest.mark.parametrize("n_process", [1, 2]) def test_language_pipe_error_handler_pipe(en_vocab, n_process): """Test the error handling of a component's pipe method""" @@ -515,19 +547,19 @@ def test_spacy_blank(): @pytest.mark.parametrize( "lang,target", [ - ('en', 'en'), - ('fra', 'fr'), - ('fre', 'fr'), - ('iw', 'he'), - ('mo', 'ro'), - ('mul', 'xx'), - ('no', 'nb'), - ('pt-BR', 'pt'), - ('xx', 'xx'), - ('zh-Hans', 'zh'), - ('zh-Hant', None), - ('zxx', None) - ] + ("en", "en"), + ("fra", "fr"), + ("fre", "fr"), + ("iw", "he"), + ("mo", "ro"), + ("mul", "xx"), + ("no", "nb"), + ("pt-BR", "pt"), + ("xx", "xx"), + ("zh-Hans", "zh"), + ("zh-Hant", None), + ("zxx", None), + ], ) def test_language_matching(lang, target): """ @@ -540,17 +572,17 @@ def test_language_matching(lang, target): @pytest.mark.parametrize( "lang,target", [ - ('en', 'en'), - ('fra', 'fr'), - ('fre', 'fr'), - ('iw', 'he'), - ('mo', 'ro'), - ('mul', 'xx'), - ('no', 'nb'), - ('pt-BR', 'pt'), - ('xx', 'xx'), - ('zh-Hans', 'zh'), - ] + ("en", "en"), + ("fra", "fr"), + ("fre", "fr"), + ("iw", "he"), + ("mo", "ro"), + ("mul", "xx"), + ("no", "nb"), + ("pt-BR", "pt"), + ("xx", "xx"), + ("zh-Hans", "zh"), + ], ) def test_blank_languages(lang, target): """