From c9baf9d196cba07fe1b1c636bcab3c80c6b81b44 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 15 Nov 2021 12:40:55 +0100 Subject: [PATCH] Fix spancat for empty docs and zero suggestions (#9654) * Fix spancat for empty docs and zero suggestions * Use ops.xp.zeros in test --- spacy/ml/extract_spans.py | 10 +++++++-- spacy/pipeline/spancat.py | 2 +- spacy/tests/pipeline/test_spancat.py | 31 +++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index 9bc972032..edc86ff9c 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -28,7 +28,13 @@ def forward( X, spans = source_spans assert spans.dataXd.ndim == 2 indices = _get_span_indices(ops, spans, X.lengths) - Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index] + if len(indices) > 0: + Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index] + else: + Y = Ragged( + ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype), + ops.xp.zeros((len(X.lengths),), dtype="i"), + ) x_shape = X.dataXd.shape x_lengths = X.lengths @@ -53,7 +59,7 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d: for j in range(spans_i.shape[0]): indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index] offset += length - return ops.flatten(indices) + return ops.flatten(indices, dtype="i", ndim_if_empty=1) def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 5b84ce8fb..829def1eb 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -78,7 +78,7 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester: if len(spans) > 0: output = Ragged(ops.xp.vstack(spans), lengths_array) else: - output = Ragged(ops.xp.zeros((0, 0)), lengths_array) + output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) assert output.dataXd.ndim == 2 return output diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 5c3a9d27d..2f7e952d3 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -1,7 +1,7 @@ import pytest import numpy from numpy.testing import assert_array_equal, assert_almost_equal -from thinc.api import get_current_ops +from thinc.api import get_current_ops, Ragged from spacy import util from spacy.lang.en import English @@ -29,6 +29,7 @@ TRAIN_DATA_OVERLAPPING = [ "I like London and Berlin", {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}}, ), + ("", {"spans": {SPAN_KEY: []}}), ] @@ -365,3 +366,31 @@ def test_overfitting_IO_overlapping(): "London and Berlin", } assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} + + +def test_zero_suggestions(): + # Test with a suggester that returns 0 suggestions + + @registry.misc("test_zero_suggester") + def make_zero_suggester(): + def zero_suggester(docs, *, ops=None): + if ops is None: + ops = get_current_ops() + return Ragged( + ops.xp.zeros((0, 0), dtype="i"), ops.xp.zeros((len(docs),), dtype="i") + ) + + return zero_suggester + + fix_random_seed(0) + nlp = English() + spancat = nlp.add_pipe( + "spancat", + config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY}, + ) + train_examples = make_examples(nlp) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + assert spancat.model.get_dim("nO") == 2 + assert set(spancat.labels) == {"LOC", "PERSON"} + + nlp.update(train_examples, sgd=optimizer)