From 69e20ce03d5a819db9d868f0b8872f60f71a6a70 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 31 Mar 2023 13:43:51 +0200 Subject: [PATCH] Fix pickle for ngram suggester (#12486) --- spacy/pipeline/spancat.py | 58 +++++++++++++++------------- spacy/tests/pipeline/test_spancat.py | 20 +++++++++- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index ff68a3703..5a087e42a 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,5 +1,6 @@ from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union from dataclasses import dataclass +from functools import partial from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer from thinc.types import Ragged, Ints2d, Floats2d @@ -82,39 +83,42 @@ class Suggester(Protocol): ... +def ngram_suggester( + docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None +) -> Ragged: + if ops is None: + ops = get_current_ops() + spans = [] + lengths = [] + for doc in docs: + starts = ops.xp.arange(len(doc), dtype="i") + starts = starts.reshape((-1, 1)) + length = 0 + for size in sizes: + if size <= len(doc): + starts_size = starts[: len(doc) - (size - 1)] + spans.append(ops.xp.hstack((starts_size, starts_size + size))) + length += spans[-1].shape[0] + if spans: + assert spans[-1].ndim == 2, spans[-1].shape + lengths.append(length) + lengths_array = ops.asarray1i(lengths) + if len(spans) > 0: + output = Ragged(ops.xp.vstack(spans), lengths_array) + else: + output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) + + assert output.dataXd.ndim == 2 + return output + + @registry.misc("spacy.ngram_suggester.v1") def build_ngram_suggester(sizes: List[int]) -> Suggester: """Suggest all spans of the given lengths. Spans are returned as a ragged array of integers. The array has two columns, indicating the start and end position.""" - def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged: - if ops is None: - ops = get_current_ops() - spans = [] - lengths = [] - for doc in docs: - starts = ops.xp.arange(len(doc), dtype="i") - starts = starts.reshape((-1, 1)) - length = 0 - for size in sizes: - if size <= len(doc): - starts_size = starts[: len(doc) - (size - 1)] - spans.append(ops.xp.hstack((starts_size, starts_size + size))) - length += spans[-1].shape[0] - if spans: - assert spans[-1].ndim == 2, spans[-1].shape - lengths.append(length) - lengths_array = ops.asarray1i(lengths) - if len(spans) > 0: - output = Ragged(ops.xp.vstack(spans), lengths_array) - else: - output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) - - assert output.dataXd.ndim == 2 - return output - - return ngram_suggester + return partial(ngram_suggester, sizes=sizes) @registry.misc("spacy.ngram_range_suggester.v1") diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index b06505a6d..199ef2b2a 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -1,7 +1,7 @@ import pytest import numpy from numpy.testing import assert_array_equal, assert_almost_equal -from thinc.api import get_current_ops, Ragged +from thinc.api import get_current_ops, NumpyOps, Ragged from spacy import util from spacy.lang.en import English @@ -577,3 +577,21 @@ def test_set_candidates(name): assert len(docs[0].spans["candidates"]) == 9 assert docs[0].spans["candidates"][0].text == "Just" assert docs[0].spans["candidates"][4].text == "Just a" + + +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +@pytest.mark.parametrize("n_process", [1, 2]) +def test_spancat_multiprocessing(name, n_process): + if isinstance(get_current_ops, NumpyOps) or n_process < 2: + nlp = Language() + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) + train_examples = make_examples(nlp) + nlp.initialize(get_examples=lambda: train_examples) + texts = [ + "Just a sentence.", + "I like London and Berlin", + "I like Berlin", + "I eat ham.", + ] + docs = list(nlp.pipe(texts, n_process=n_process)) + assert len(docs) == len(texts)