mirror of https://github.com/explosion/spaCy.git
Fix pickle for ngram suggester (#12486)
This commit is contained in:
parent
140d53649d
commit
69e20ce03d
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
from thinc.types import Ragged, Ints2d, Floats2d
|
||||
|
@ -82,39 +83,42 @@ class Suggester(Protocol):
|
|||
...
|
||||
|
||||
|
||||
def ngram_suggester(
|
||||
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
|
||||
) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
starts = ops.xp.arange(len(doc), dtype="i")
|
||||
starts = starts.reshape((-1, 1))
|
||||
length = 0
|
||||
for size in sizes:
|
||||
if size <= len(doc):
|
||||
starts_size = starts[: len(doc) - (size - 1)]
|
||||
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
|
||||
length += spans[-1].shape[0]
|
||||
if spans:
|
||||
assert spans[-1].ndim == 2, spans[-1].shape
|
||||
lengths.append(length)
|
||||
lengths_array = ops.asarray1i(lengths)
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.xp.vstack(spans), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
|
||||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
position."""
|
||||
|
||||
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
starts = ops.xp.arange(len(doc), dtype="i")
|
||||
starts = starts.reshape((-1, 1))
|
||||
length = 0
|
||||
for size in sizes:
|
||||
if size <= len(doc):
|
||||
starts_size = starts[: len(doc) - (size - 1)]
|
||||
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
|
||||
length += spans[-1].shape[0]
|
||||
if spans:
|
||||
assert spans[-1].ndim == 2, spans[-1].shape
|
||||
lengths.append(length)
|
||||
lengths_array = ops.asarray1i(lengths)
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.xp.vstack(spans), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
|
||||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
return ngram_suggester
|
||||
return partial(ngram_suggester, sizes=sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal, assert_almost_equal
|
||||
from thinc.api import get_current_ops, Ragged
|
||||
from thinc.api import get_current_ops, NumpyOps, Ragged
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
|
@ -577,3 +577,21 @@ def test_set_candidates(name):
|
|||
assert len(docs[0].spans["candidates"]) == 9
|
||||
assert docs[0].spans["candidates"][0].text == "Just"
|
||||
assert docs[0].spans["candidates"][4].text == "Just a"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
def test_spancat_multiprocessing(name, n_process):
|
||||
if isinstance(get_current_ops, NumpyOps) or n_process < 2:
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
train_examples = make_examples(nlp)
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
texts = [
|
||||
"Just a sentence.",
|
||||
"I like London and Berlin",
|
||||
"I like Berlin",
|
||||
"I eat ham.",
|
||||
]
|
||||
docs = list(nlp.pipe(texts, n_process=n_process))
|
||||
assert len(docs) == len(texts)
|
||||
|
|
Loading…
Reference in New Issue