Add doc_cleaner component (#9659)

* Add doc_cleaner component

* Fix types

* Fix loop

* Rephrase method description
Adriane Boyd, 2021-11-23 15:33:33 +01:00 (committed by GitHub)
parent a77f50baa4
commit 9ac6d4991e
4 changed files with 112 additions and 0 deletions

spacy/errors.py
@@ -191,6 +191,7 @@ class Warnings(metaclass=ErrorsWithCodes):
"lead to errors.")
W115 = ("Skipping {method}: the floret vector table cannot be modified. "
"Vectors are calculated from character ngrams.")
W116 = ("Unable to clean attribute '{attr}'.")
class Errors(metaclass=ErrorsWithCodes):
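For context on the new warning, a minimal sketch (not part of this diff, assuming a spaCy build that includes this commit) of when W116 fires: with `silent=False`, `doc_cleaner` warns about any configured attribute it cannot resolve on the `Doc`.

```python
import warnings

import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "doc_cleaner",
    # "does_not_exist" is a deliberately bogus attribute for illustration.
    config={"attrs": {"does_not_exist": None}, "silent": False},
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    nlp("some text")

# The Doc has no such attribute, so the component emits W116.
assert any("W116" in str(w.message) for w in caught)
```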

spacy/pipeline/functions.py
@@ -1,6 +1,8 @@
from typing import Dict, Any
import srsly
import warnings

from ..errors import Warnings
from ..language import Language
from ..matcher import Matcher
from ..tokens import Doc
@@ -136,3 +138,65 @@ class TokenSplitter:
            "cfg": lambda p: self._set_config(srsly.read_json(p)),
        }
        util.from_disk(path, serializers, [])
@Language.factory(
    "doc_cleaner",
    default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
)
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
    return DocCleaner(attrs, silent=silent)


class DocCleaner:
    def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
        self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}

    def __call__(self, doc: Doc) -> Doc:
        attrs: dict = self.cfg["attrs"]
        silent: bool = self.cfg["silent"]
        for attr, value in attrs.items():
            obj = doc
            parts = attr.split(".")
            skip = False
            # Resolve dotted paths like "_.trf_data" one part at a time.
            for part in parts[:-1]:
                if hasattr(obj, part):
                    obj = getattr(obj, part)
                else:
                    skip = True
                    if not silent:
                        warnings.warn(Warnings.W116.format(attr=attr))
            if not skip:
                if hasattr(obj, parts[-1]):
                    setattr(obj, parts[-1], value)
                else:
                    if not silent:
                        warnings.warn(Warnings.W116.format(attr=attr))
        return doc

    def to_bytes(self, **kwargs):
        serializers = {
            "cfg": lambda: srsly.json_dumps(self.cfg),
        }
        return util.to_bytes(serializers, [])

    def from_bytes(self, data, **kwargs):
        deserializers = {
            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
        }
        util.from_bytes(data, deserializers, [])
        return self

    def to_disk(self, path, **kwargs):
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: srsly.write_json(p, self.cfg),
        }
        return util.to_disk(path, serializers, [])

    def from_disk(self, path, **kwargs):
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
        }
        util.from_disk(path, serializers, [])
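The attribute names handled above are dotted paths: each part before the last is resolved with `getattr`, and the final part is overwritten with `setattr`. That is how `_.trf_data` reaches the `Doc._` extension namespace, while a plain name like `tensor` is set directly on the `Doc`. A standalone sketch of that traversal, outside the diff (the helper name `resolve_and_set` is illustrative only):

```python
from typing import Any


def resolve_and_set(obj: Any, attr: str, value: Any) -> bool:
    """Restate the traversal loop from DocCleaner.__call__ above."""
    parts = attr.split(".")
    # Walk everything before the last part: for "_.trf_data" this
    # resolves obj = getattr(doc, "_") before touching "trf_data".
    for part in parts[:-1]:
        if not hasattr(obj, part):
            return False  # DocCleaner warns with W116 here unless silent
        obj = getattr(obj, part)
    # Plain names like "tensor" skip the loop and are set directly.
    if not hasattr(obj, parts[-1]):
        return False
    setattr(obj, parts[-1], value)
    return True
```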

spacy/tests/pipeline/test_functions.py
@@ -3,6 +3,8 @@ from spacy.pipeline.functions import merge_subtokens
from spacy.language import Language
from spacy.tokens import Span, Doc

from ..doc.test_underscore import clean_underscore  # noqa: F401


@pytest.fixture
def doc(en_vocab):
@@ -74,3 +76,26 @@ def test_token_splitter():
        "i",
    ]
    assert all(len(t.text) <= token_splitter.split_length for t in doc)
@pytest.mark.usefixtures("clean_underscore")
def test_factories_doc_cleaner():
nlp = Language()
nlp.add_pipe("doc_cleaner")
doc = nlp.make_doc("text")
doc.tensor = [1, 2, 3]
doc = nlp(doc)
assert doc.tensor is None
nlp = Language()
nlp.add_pipe("doc_cleaner", config={"silent": False})
with pytest.warns(UserWarning):
doc = nlp("text")
Doc.set_extension("test_attr", default=-1)
nlp = Language()
nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})
doc = nlp.make_doc("text")
doc._.test_attr = 100
doc = nlp(doc)
assert doc._.test_attr == 0

website/docs/api/pipeline-functions.md
@@ -130,3 +130,25 @@ exceed the transformer model max length.
| `min_length` | The minimum length for a token to be split. Defaults to `25`. ~~int~~ |
| `split_length` | The length of the split tokens. Defaults to `5`. ~~int~~ |
| **RETURNS** | The modified `Doc` with the split tokens. ~~Doc~~ |
## doc_cleaner {#doc_cleaner tag="function" new="3.2.1"}
Clean up `Doc` attributes. Intended for use at the end of pipelines with
`tok2vec` or `transformer` pipeline components that store tensors and other
values that can require a lot of memory and frequently aren't needed after the
whole pipeline has run.
> #### Example
>
> ```python
> config = {"attrs": {"tensor": None}}
> nlp.add_pipe("doc_cleaner", config=config)
> doc = nlp("text")
> assert doc.tensor is None
> ```
| Setting | Description |
| ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `attrs` | A dict of the `Doc` attributes and the values to set them to. Defaults to `{"tensor": None, "_.trf_data": None}` to clean up after `tok2vec` and `transformer` components. ~~dict~~ |
| `silent` | If `False`, show warnings if attributes aren't found or can't be set. Defaults to `True`. ~~bool~~ |
| **RETURNS** | The modified `Doc` with the modified attributes. ~~Doc~~ |
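The defaults cover the output of `tok2vec` (`doc.tensor`) and `transformer` (`doc._.trf_data`). As a rough sketch of the memory motivation, assuming a trained pipeline such as `en_core_web_sm` is installed alongside a spaCy version that ships `doc_cleaner` (v3.2.1+):

```python
import spacy

# Assumption: en_core_web_sm (which includes a tok2vec component) is
# installed; the doc_cleaner factory requires spaCy v3.2.1 or later.
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("doc_cleaner")  # added last, so it runs after all components

# Tensors are dropped once the full pipeline has run, so retaining the
# Docs doesn't also retain their tok2vec output in memory.
docs = list(nlp.pipe(["first text", "second text"]))
assert all(doc.tensor is None for doc in docs)
```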