mirror of https://github.com/explosion/spaCy.git
Add test for Language.pipe as_tuples with custom error handlers (#9608)
* make nlp.pipe() return None docs when no exceptions are (re-)raised during error handling * Remove changes other than as_tuples test * Only check warning count for one process * Fix types * Format Co-authored-by: Xi Bai <xi.bai.ed@gmail.com>
This commit is contained in:
parent
79cea03983
commit
db0d8c56d0
|
@ -1537,8 +1537,7 @@ class Language:
|
||||||
yield (doc, context)
|
yield (doc, context)
|
||||||
return
|
return
|
||||||
|
|
||||||
# At this point, we know that we're dealing with an iterable of plain texts
|
texts = cast(Iterable[Union[str, Doc]], texts)
|
||||||
texts = cast(Iterable[str], texts)
|
|
||||||
|
|
||||||
# Set argument defaults
|
# Set argument defaults
|
||||||
if n_process == -1:
|
if n_process == -1:
|
||||||
|
@ -1592,7 +1591,7 @@ class Language:
|
||||||
|
|
||||||
def _multiprocessing_pipe(
|
def _multiprocessing_pipe(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[Union[str, Doc]],
|
||||||
pipes: Iterable[Callable[..., Iterator[Doc]]],
|
pipes: Iterable[Callable[..., Iterator[Doc]]],
|
||||||
n_process: int,
|
n_process: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
|
|
@ -255,6 +255,38 @@ def test_language_pipe_error_handler_custom(en_vocab, n_process):
|
||||||
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
|
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_process", [1, 2])
|
||||||
|
def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process):
|
||||||
|
"""Test the error handling of nlp.pipe with input as tuples"""
|
||||||
|
Language.component("my_evil_component", func=evil_component)
|
||||||
|
ops = get_current_ops()
|
||||||
|
if isinstance(ops, NumpyOps) or n_process < 2:
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe("my_evil_component")
|
||||||
|
texts = [
|
||||||
|
("TEXT 111", 111),
|
||||||
|
("TEXT 222", 222),
|
||||||
|
("TEXT 333", 333),
|
||||||
|
("TEXT 342", 342),
|
||||||
|
("TEXT 666", 666),
|
||||||
|
]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
list(nlp.pipe(texts, as_tuples=True))
|
||||||
|
nlp.set_error_handler(warn_error)
|
||||||
|
logger = logging.getLogger("spacy")
|
||||||
|
with mock.patch.object(logger, "warning") as mock_warning:
|
||||||
|
tuples = list(nlp.pipe(texts, as_tuples=True, n_process=n_process))
|
||||||
|
# HACK/TODO? the warnings in child processes don't seem to be
|
||||||
|
# detected by the mock logger
|
||||||
|
if n_process == 1:
|
||||||
|
mock_warning.assert_called()
|
||||||
|
assert mock_warning.call_count == 2
|
||||||
|
assert len(tuples) + mock_warning.call_count == len(texts)
|
||||||
|
assert (tuples[0][0].text, tuples[0][1]) == ("TEXT 111", 111)
|
||||||
|
assert (tuples[1][0].text, tuples[1][1]) == ("TEXT 333", 333)
|
||||||
|
assert (tuples[2][0].text, tuples[2][1]) == ("TEXT 666", 666)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_process", [1, 2])
|
@pytest.mark.parametrize("n_process", [1, 2])
|
||||||
def test_language_pipe_error_handler_pipe(en_vocab, n_process):
|
def test_language_pipe_error_handler_pipe(en_vocab, n_process):
|
||||||
"""Test the error handling of a component's pipe method"""
|
"""Test the error handling of a component's pipe method"""
|
||||||
|
@ -515,19 +547,19 @@ def test_spacy_blank():
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"lang,target",
|
"lang,target",
|
||||||
[
|
[
|
||||||
('en', 'en'),
|
("en", "en"),
|
||||||
('fra', 'fr'),
|
("fra", "fr"),
|
||||||
('fre', 'fr'),
|
("fre", "fr"),
|
||||||
('iw', 'he'),
|
("iw", "he"),
|
||||||
('mo', 'ro'),
|
("mo", "ro"),
|
||||||
('mul', 'xx'),
|
("mul", "xx"),
|
||||||
('no', 'nb'),
|
("no", "nb"),
|
||||||
('pt-BR', 'pt'),
|
("pt-BR", "pt"),
|
||||||
('xx', 'xx'),
|
("xx", "xx"),
|
||||||
('zh-Hans', 'zh'),
|
("zh-Hans", "zh"),
|
||||||
('zh-Hant', None),
|
("zh-Hant", None),
|
||||||
('zxx', None)
|
("zxx", None),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
def test_language_matching(lang, target):
|
def test_language_matching(lang, target):
|
||||||
"""
|
"""
|
||||||
|
@ -540,17 +572,17 @@ def test_language_matching(lang, target):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"lang,target",
|
"lang,target",
|
||||||
[
|
[
|
||||||
('en', 'en'),
|
("en", "en"),
|
||||||
('fra', 'fr'),
|
("fra", "fr"),
|
||||||
('fre', 'fr'),
|
("fre", "fr"),
|
||||||
('iw', 'he'),
|
("iw", "he"),
|
||||||
('mo', 'ro'),
|
("mo", "ro"),
|
||||||
('mul', 'xx'),
|
("mul", "xx"),
|
||||||
('no', 'nb'),
|
("no", "nb"),
|
||||||
('pt-BR', 'pt'),
|
("pt-BR", "pt"),
|
||||||
('xx', 'xx'),
|
("xx", "xx"),
|
||||||
('zh-Hans', 'zh'),
|
("zh-Hans", "zh"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
def test_blank_languages(lang, target):
|
def test_blank_languages(lang, target):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue