diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 858c7cbb6..19b554572 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -1,6 +1,7 @@ import weakref import numpy +from numpy.testing import assert_array_equal import pytest from thinc.api import NumpyOps, get_current_ops @@ -634,6 +635,14 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer): assert "group" in m_doc.spans assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]]) + # can exclude spans + m_doc = Doc.from_docs(en_docs, exclude=["spans"]) + assert "group" not in m_doc.spans + + # can exclude user_data + m_doc = Doc.from_docs(en_docs, exclude=["user_data"]) + assert m_doc.user_data == {} + # can merge empty docs doc = Doc.from_docs([en_tokenizer("")] * 10) @@ -647,6 +656,20 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer): assert "group" in m_doc.spans assert len(m_doc.spans["group"]) == 0 + # with tensor + ops = get_current_ops() + for doc in en_docs: + doc.tensor = ops.asarray([[len(t.text), 0.0] for t in doc]) + m_doc = Doc.from_docs(en_docs) + assert_array_equal( + ops.to_numpy(m_doc.tensor), + ops.to_numpy(ops.xp.vstack([doc.tensor for doc in en_docs if len(doc)])), + ) + + # can exclude tensor + m_doc = Doc.from_docs(en_docs, exclude=["tensor"]) + assert m_doc.tensor.shape == (0,) + def test_doc_api_from_docs_ents(en_tokenizer): texts = ["Merging the docs is fun.", "They don't think alike."] diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 1a48705fd..c36e3a02f 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -11,7 +11,7 @@ from enum import Enum import itertools import numpy import srsly -from thinc.api import get_array_module +from thinc.api import get_array_module, get_current_ops from thinc.util import copy_array import warnings @@ -1108,14 +1108,19 @@ cdef class Doc: return self @staticmethod - def from_docs(docs, ensure_whitespace=True, attrs=None): + def from_docs(docs, ensure_whitespace=True, attrs=None, *, exclude=tuple()): """Concatenate multiple Doc objects to form a new one. Raises an error if the `Doc` objects do not all share the same `Vocab`. docs (list): A list of Doc objects. - ensure_whitespace (bool): Insert a space between two adjacent docs whenever the first doc does not end in whitespace. - attrs (list): Optional list of attribute ID ints or attribute name strings. - RETURNS (Doc): A doc that contains the concatenated docs, or None if no docs were given. + ensure_whitespace (bool): Insert a space between two adjacent docs + whenever the first doc does not end in whitespace. + attrs (list): Optional list of attribute ID ints or attribute name + strings. + exclude (Iterable[str]): Doc attributes to exclude. Supported + attributes: `spans`, `tensor`, `user_data`. + RETURNS (Doc): A doc that contains the concatenated docs, or None if no + docs were given. DOCS: https://spacy.io/api/doc#from_docs """ @@ -1145,31 +1150,33 @@ cdef class Doc: concat_words.extend(t.text for t in doc) concat_spaces.extend(bool(t.whitespace_) for t in doc) - for key, value in doc.user_data.items(): - if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": - data_type, name, start, end = key - if start is not None or end is not None: - start += char_offset - if end is not None: - end += char_offset - concat_user_data[(data_type, name, start, end)] = copy.copy(value) + if "user_data" not in exclude: + for key, value in doc.user_data.items(): + if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": + data_type, name, start, end = key + if start is not None or end is not None: + start += char_offset + if end is not None: + end += char_offset + concat_user_data[(data_type, name, start, end)] = copy.copy(value) + else: + warnings.warn(Warnings.W101.format(name=name)) else: - warnings.warn(Warnings.W101.format(name=name)) - else: - warnings.warn(Warnings.W102.format(key=key, value=value)) - for key in doc.spans: - # if a spans key is in any doc, include it in the merged doc - # even if it is empty - if key not in concat_spans: - concat_spans[key] = [] - for span in doc.spans[key]: - concat_spans[key].append(( - span.start_char + char_offset, - span.end_char + char_offset, - span.label, - span.kb_id, - span.text, # included as a check - )) + warnings.warn(Warnings.W102.format(key=key, value=value)) + if "spans" not in exclude: + for key in doc.spans: + # if a spans key is in any doc, include it in the merged doc + # even if it is empty + if key not in concat_spans: + concat_spans[key] = [] + for span in doc.spans[key]: + concat_spans[key].append(( + span.start_char + char_offset, + span.end_char + char_offset, + span.label, + span.kb_id, + span.text, # included as a check + )) char_offset += len(doc.text) if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_): char_offset += 1 @@ -1210,6 +1217,10 @@ cdef class Doc: else: raise ValueError(Errors.E873.format(key=key, text=text)) + if "tensor" not in exclude and any(len(doc) for doc in docs): + ops = get_current_ops() + concat_doc.tensor = ops.xp.vstack([ops.asarray(doc.tensor) for doc in docs if len(doc)]) + return concat_doc def get_lca_matrix(self): diff --git a/website/docs/api/doc.md b/website/docs/api/doc.md index c28509ab0..c929a4a06 100644 --- a/website/docs/api/doc.md +++ b/website/docs/api/doc.md @@ -34,7 +34,7 @@ Construct a `Doc` object. The most common way to get a `Doc` object is via the | Name | Description | | ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `vocab` | A storage container for lexical types. ~~Vocab~~ | -| `words` | A list of strings or integer hash values to add to the document as words. ~~Optional[List[Union[str,int]]]~~ | +| `words` | A list of strings or integer hash values to add to the document as words. ~~Optional[List[Union[str,int]]]~~ | | `spaces` | A list of boolean values indicating whether each word has a subsequent space. Must have the same length as `words`, if specified. Defaults to a sequence of `True`. ~~Optional[List[bool]]~~ | | _keyword-only_ | | | `user\_data` | Optional extra data to attach to the Doc. ~~Dict~~ | @@ -304,7 +304,8 @@ ancestor is found, e.g. if span excludes a necessary ancestor. ## Doc.has_annotation {#has_annotation tag="method"} -Check whether the doc contains annotation on a [`Token` attribute](/api/token#attributes). +Check whether the doc contains annotation on a +[`Token` attribute](/api/token#attributes). @@ -398,12 +399,14 @@ Concatenate multiple `Doc` objects to form a new one. Raises an error if the > [str(ent) for doc in docs for ent in doc.ents] > ``` -| Name | Description | -| ------------------- | ----------------------------------------------------------------------------------------------------------------- | -| `docs` | A list of `Doc` objects. ~~List[Doc]~~ | -| `ensure_whitespace` | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ~~bool~~ | -| `attrs` | Optional list of attribute ID ints or attribute name strings. ~~Optional[List[Union[str, int]]]~~ | -| **RETURNS** | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. ~~Optional[Doc]~~ | +| Name | Description | +| -------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | +| `docs` | A list of `Doc` objects. ~~List[Doc]~~ | +| `ensure_whitespace` | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ~~bool~~ | +| `attrs` | Optional list of attribute ID ints or attribute name strings. ~~Optional[List[Union[str, int]]]~~ | +| _keyword-only_ | | +| `exclude` 3.3 | String names of Doc attributes to exclude. Supported: `spans`, `tensor`, `user_data`. ~~Iterable[str]~~ | +| **RETURNS** | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. ~~Optional[Doc]~~ | ## Doc.to_disk {#to_disk tag="method" new="2"}