diff --git a/spacy/errors.py b/spacy/errors.py
index 5fe550145..84c407422 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -191,6 +191,7 @@ class Warnings(metaclass=ErrorsWithCodes):
             "lead to errors.")
     W115 = ("Skipping {method}: the floret vector table cannot be modified. "
             "Vectors are calculated from character ngrams.")
+    W116 = ("Unable to clean attribute '{attr}'.")
 
 
 class Errors(metaclass=ErrorsWithCodes):
diff --git a/spacy/pipeline/functions.py b/spacy/pipeline/functions.py
index f0a75dc2c..c005395bf 100644
--- a/spacy/pipeline/functions.py
+++ b/spacy/pipeline/functions.py
@@ -1,6 +1,8 @@
 from typing import Dict, Any
 import srsly
+import warnings
 
+from ..errors import Warnings
 from ..language import Language
 from ..matcher import Matcher
 from ..tokens import Doc
@@ -136,3 +138,65 @@ class TokenSplitter:
             "cfg": lambda p: self._set_config(srsly.read_json(p)),
         }
         util.from_disk(path, serializers, [])
+
+
+@Language.factory(
+    "doc_cleaner",
+    default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
+)
+def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
+    return DocCleaner(attrs, silent=silent)
+
+
+class DocCleaner:
+    def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
+        self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}
+
+    def __call__(self, doc: Doc) -> Doc:
+        attrs: dict = self.cfg["attrs"]
+        silent: bool = self.cfg["silent"]
+        for attr, value in attrs.items():
+            obj = doc
+            parts = attr.split(".")
+            skip = False
+            for part in parts[:-1]:
+                if hasattr(obj, part):
+                    obj = getattr(obj, part)
+                else:
+                    skip = True
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+            if not skip:
+                if hasattr(obj, parts[-1]):
+                    setattr(obj, parts[-1], value)
+                else:
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+        return doc
+
+    def to_bytes(self, **kwargs):
+        serializers = {
+            "cfg": lambda: srsly.json_dumps(self.cfg),
+        }
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, data, **kwargs):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+        }
+        util.from_bytes(data, deserializers, [])
+        return self
+
+    def to_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: srsly.write_json(p, self.cfg),
+        }
+        return util.to_disk(path, serializers, [])
+
+    def from_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+        }
+        util.from_disk(path, serializers, [])
diff --git a/spacy/tests/pipeline/test_functions.py b/spacy/tests/pipeline/test_functions.py
index 454d7b08b..e4adfe2fe 100644
--- a/spacy/tests/pipeline/test_functions.py
+++ b/spacy/tests/pipeline/test_functions.py
@@ -3,6 +3,8 @@ from spacy.pipeline.functions import merge_subtokens
 from spacy.language import Language
 from spacy.tokens import Span, Doc
 
+from ..doc.test_underscore import clean_underscore  # noqa: F401
+
 
 @pytest.fixture
 def doc(en_vocab):
@@ -74,3 +76,26 @@ def test_token_splitter():
         "i",
     ]
     assert all(len(t.text) <= token_splitter.split_length for t in doc)
+
+
+@pytest.mark.usefixtures("clean_underscore")
+def test_factories_doc_cleaner():
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner")
+    doc = nlp.make_doc("text")
+    doc.tensor = [1, 2, 3]
+    doc = nlp(doc)
+    assert doc.tensor is None
+
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"silent": False})
+    with pytest.warns(UserWarning):
+        doc = nlp("text")
+
+    Doc.set_extension("test_attr", default=-1)
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})
+    doc = nlp.make_doc("text")
+    doc._.test_attr = 100
+    doc = nlp(doc)
+    assert doc._.test_attr == 0
diff --git a/website/docs/api/pipeline-functions.md b/website/docs/api/pipeline-functions.md
index a776eca9b..ff19d3e71 100644
--- a/website/docs/api/pipeline-functions.md
+++ b/website/docs/api/pipeline-functions.md
@@ -130,3 +130,25 @@ exceed the transformer model max length.
 | `min_length`   | The minimum length for a token to be split. Defaults to `25`. ~~int~~ |
 | `split_length` | The length of the split tokens. Defaults to `5`. ~~int~~              |
 | **RETURNS**    | The modified `Doc` with the split tokens. ~~Doc~~                     |
+
+## doc_cleaner {#doc_cleaner tag="function" new="3.2.1"}
+
+Clean up `Doc` attributes. Intended for use at the end of pipelines with
+`tok2vec` or `transformer` pipeline components that store tensors and other
+values that can require a lot of memory and frequently aren't needed after the
+whole pipeline has run.
+
+> #### Example
+>
+> ```python
+> config = {"attrs": {"tensor": None}}
+> nlp.add_pipe("doc_cleaner", config=config)
+> doc = nlp("text")
+> assert doc.tensor is None
+> ```
+
+| Setting     | Description                                                                                                                                                                          |
+| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `attrs`     | A dict of the `Doc` attributes and the values to set them to. Defaults to `{"tensor": None, "_.trf_data": None}` to clean up after `tok2vec` and `transformer` components. ~~dict~~ |
+| `silent`    | If `False`, show warnings if attributes aren't found or can't be set. Defaults to `True`. ~~bool~~                                                                                   |
+| **RETURNS** | The modified `Doc` with the modified attributes. ~~Doc~~                                                                                                                             |
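
For reference, a minimal usage sketch of the component added above, mirroring the custom-extension case from the test in this diff. It assumes a spaCy version that ships the `doc_cleaner` factory (v3.2.1+); the blank English pipeline is used purely for illustration.

```python
import spacy
from spacy.tokens import Doc

# Register a custom extension with a default value, then have doc_cleaner
# reset it after the pipeline has run (as in the test case above).
Doc.set_extension("test_attr", default=-1)

# Assumption: any pipeline works; a blank English pipeline keeps the sketch small.
nlp = spacy.blank("en")
nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})

doc = nlp.make_doc("text")
doc._.test_attr = 100
doc = nlp(doc)
assert doc._.test_attr == 0
```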