Support exclude in Doc.from_docs (#10689)

* Support exclude in Doc.from_docs

* Update API docs

* Add new tag to docs
This commit is contained in:
Adriane Boyd 2022-04-25 18:19:03 +02:00 committed by GitHub
parent 3b208197c3
commit 455f089c9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 37 deletions

View File

@ -1,6 +1,7 @@
import weakref
import numpy
from numpy.testing import assert_array_equal
import pytest
from thinc.api import NumpyOps, get_current_ops
@ -634,6 +635,14 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
assert "group" in m_doc.spans
assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
# can exclude spans
m_doc = Doc.from_docs(en_docs, exclude=["spans"])
assert "group" not in m_doc.spans
# can exclude user_data
m_doc = Doc.from_docs(en_docs, exclude=["user_data"])
assert m_doc.user_data == {}
# can merge empty docs
doc = Doc.from_docs([en_tokenizer("")] * 10)
@ -647,6 +656,20 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
assert "group" in m_doc.spans
assert len(m_doc.spans["group"]) == 0
# with tensor
ops = get_current_ops()
for doc in en_docs:
doc.tensor = ops.asarray([[len(t.text), 0.0] for t in doc])
m_doc = Doc.from_docs(en_docs)
assert_array_equal(
ops.to_numpy(m_doc.tensor),
ops.to_numpy(ops.xp.vstack([doc.tensor for doc in en_docs if len(doc)])),
)
# can exclude tensor
m_doc = Doc.from_docs(en_docs, exclude=["tensor"])
assert m_doc.tensor.shape == (0,)
def test_doc_api_from_docs_ents(en_tokenizer):
texts = ["Merging the docs is fun.", "They don't think alike."]

View File

@ -11,7 +11,7 @@ from enum import Enum
import itertools
import numpy
import srsly
from thinc.api import get_array_module
from thinc.api import get_array_module, get_current_ops
from thinc.util import copy_array
import warnings
@ -1108,14 +1108,19 @@ cdef class Doc:
return self
@staticmethod
def from_docs(docs, ensure_whitespace=True, attrs=None):
def from_docs(docs, ensure_whitespace=True, attrs=None, *, exclude=tuple()):
"""Concatenate multiple Doc objects to form a new one. Raises an error
if the `Doc` objects do not all share the same `Vocab`.
docs (list): A list of Doc objects.
ensure_whitespace (bool): Insert a space between two adjacent docs whenever the first doc does not end in whitespace.
attrs (list): Optional list of attribute ID ints or attribute name strings.
RETURNS (Doc): A doc that contains the concatenated docs, or None if no docs were given.
ensure_whitespace (bool): Insert a space between two adjacent docs
whenever the first doc does not end in whitespace.
attrs (list): Optional list of attribute ID ints or attribute name
strings.
exclude (Iterable[str]): Doc attributes to exclude. Supported
attributes: `spans`, `tensor`, `user_data`.
RETURNS (Doc): A doc that contains the concatenated docs, or None if no
docs were given.
DOCS: https://spacy.io/api/doc#from_docs
"""
@ -1145,6 +1150,7 @@ cdef class Doc:
concat_words.extend(t.text for t in doc)
concat_spaces.extend(bool(t.whitespace_) for t in doc)
if "user_data" not in exclude:
for key, value in doc.user_data.items():
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
data_type, name, start, end = key
@ -1157,6 +1163,7 @@ cdef class Doc:
warnings.warn(Warnings.W101.format(name=name))
else:
warnings.warn(Warnings.W102.format(key=key, value=value))
if "spans" not in exclude:
for key in doc.spans:
# if a spans key is in any doc, include it in the merged doc
# even if it is empty
@ -1210,6 +1217,10 @@ cdef class Doc:
else:
raise ValueError(Errors.E873.format(key=key, text=text))
if "tensor" not in exclude and any(len(doc) for doc in docs):
ops = get_current_ops()
concat_doc.tensor = ops.xp.vstack([ops.asarray(doc.tensor) for doc in docs if len(doc)])
return concat_doc
def get_lca_matrix(self):

View File

@ -304,7 +304,8 @@ ancestor is found, e.g. if span excludes a necessary ancestor.
## Doc.has_annotation {#has_annotation tag="method"}
Check whether the doc contains annotation on a [`Token` attribute](/api/token#attributes).
Check whether the doc contains annotation on a
[`Token` attribute](/api/token#attributes).
<Infobox title="Changed in v3.0" variant="warning">
@ -399,10 +400,12 @@ Concatenate multiple `Doc` objects to form a new one. Raises an error if the
> ```
| Name | Description |
| ------------------- | ----------------------------------------------------------------------------------------------------------------- |
| -------------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
| `docs` | A list of `Doc` objects. ~~List[Doc]~~ |
| `ensure_whitespace` | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ~~bool~~ |
| `attrs` | Optional list of attribute ID ints or attribute name strings. ~~Optional[List[Union[str, int]]]~~ |
| _keyword-only_ | |
| `exclude` <Tag variant="new">3.3</Tag> | String names of Doc attributes to exclude. Supported: `spans`, `tensor`, `user_data`. ~~Iterable[str]~~ |
| **RETURNS** | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. ~~Optional[Doc]~~ |
## Doc.to_disk {#to_disk tag="method" new="2"}