mirror of https://github.com/explosion/spaCy.git
Add doc_cleaner component (#9659)
* Add doc_cleaner component
* Fix types
* Fix loop
* Rephrase method description

This commit is contained in:
parent a77f50baa4
commit 9ac6d4991e
spacy/errors.py
@@ -191,6 +191,7 @@ class Warnings(metaclass=ErrorsWithCodes):
             "lead to errors.")
     W115 = ("Skipping {method}: the floret vector table cannot be modified. "
             "Vectors are calculated from character ngrams.")
+    W116 = ("Unable to clean attribute '{attr}'.")
 
 
 class Errors(metaclass=ErrorsWithCodes):
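A minimal sketch of when the new W116 warning surfaces, assuming the `doc_cleaner` component added below; the attribute name `no_such_attr` is made up for illustration:

```python
import warnings

from spacy.language import Language

nlp = Language()
# "no_such_attr" does not exist on Doc, so the cleaner cannot resolve it.
nlp.add_pipe("doc_cleaner", config={"attrs": {"no_such_attr": None}, "silent": False})
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    nlp("some text")
# spaCy prefixes warning messages with their code, e.g. "[W116] ...".
assert any("W116" in str(w.message) for w in caught)
```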
spacy/pipeline/functions.py
@@ -1,6 +1,8 @@
+from typing import Dict, Any
 import srsly
+import warnings
 
 from ..errors import Warnings
 from ..language import Language
 from ..matcher import Matcher
 from ..tokens import Doc
@@ -136,3 +138,65 @@ class TokenSplitter:
             "cfg": lambda p: self._set_config(srsly.read_json(p)),
         }
         util.from_disk(path, serializers, [])
+
+
+@Language.factory(
+    "doc_cleaner",
+    default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
+)
+def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
+    return DocCleaner(attrs, silent=silent)
+
+
+class DocCleaner:
+    def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
+        self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}
+
+    def __call__(self, doc: Doc) -> Doc:
+        attrs: dict = self.cfg["attrs"]
+        silent: bool = self.cfg["silent"]
+        for attr, value in attrs.items():
+            obj = doc
+            parts = attr.split(".")
+            skip = False
+            for part in parts[:-1]:
+                if hasattr(obj, part):
+                    obj = getattr(obj, part)
+                else:
+                    skip = True
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+            if not skip:
+                if hasattr(obj, parts[-1]):
+                    setattr(obj, parts[-1], value)
+                else:
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+        return doc
+
+    def to_bytes(self, **kwargs):
+        serializers = {
+            "cfg": lambda: srsly.json_dumps(self.cfg),
+        }
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, data, **kwargs):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+        }
+        util.from_bytes(data, deserializers, [])
+        return self
+
+    def to_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: srsly.write_json(p, self.cfg),
+        }
+        return util.to_disk(path, serializers, [])
+
+    def from_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+        }
+        util.from_disk(path, serializers, [])
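The four serialization methods above round-trip only the `cfg` dict. A minimal sketch of that round trip; the two-pipeline setup is illustrative, not part of the diff:

```python
from spacy.language import Language

nlp = Language()
cleaner = nlp.add_pipe("doc_cleaner", config={"silent": False})
data = cleaner.to_bytes()

# Restore the settings onto a freshly created component.
nlp2 = Language()
cleaner2 = nlp2.add_pipe("doc_cleaner")
cleaner2.from_bytes(data)
assert cleaner2.cfg == cleaner.cfg
```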
spacy/tests/pipeline/test_functions.py
@@ -3,6 +3,8 @@ from spacy.pipeline.functions import merge_subtokens
 from spacy.language import Language
 from spacy.tokens import Span, Doc
 
+from ..doc.test_underscore import clean_underscore  # noqa: F401
+
 
 @pytest.fixture
 def doc(en_vocab):
@@ -74,3 +76,26 @@ def test_token_splitter():
         "i",
     ]
     assert all(len(t.text) <= token_splitter.split_length for t in doc)
+
+
+@pytest.mark.usefixtures("clean_underscore")
+def test_factories_doc_cleaner():
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner")
+    doc = nlp.make_doc("text")
+    doc.tensor = [1, 2, 3]
+    doc = nlp(doc)
+    assert doc.tensor is None
+
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"silent": False})
+    with pytest.warns(UserWarning):
+        doc = nlp("text")
+
+    Doc.set_extension("test_attr", default=-1)
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})
+    doc = nlp.make_doc("text")
+    doc._.test_attr = 100
+    doc = nlp(doc)
+    assert doc._.test_attr == 0
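Beyond the unit tests, a usage sketch of the memory motivation; this assumes the `en_core_web_sm` pipeline is installed, though any pipeline with a `tok2vec` component would do:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
# Append the cleaner so tensors are dropped once all components have run.
nlp.add_pipe("doc_cleaner", last=True)
docs = list(nlp.pipe(["First text.", "Second text."]))
# The per-doc tok2vec tensors have been cleared to save memory.
assert all(doc.tensor is None for doc in docs)
```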
website/docs/api/pipeline-functions.md
@@ -130,3 +130,25 @@ exceed the transformer model max length.
 | `min_length`   | The minimum length for a token to be split. Defaults to `25`. ~~int~~ |
 | `split_length` | The length of the split tokens. Defaults to `5`. ~~int~~ |
 | **RETURNS**    | The modified `Doc` with the split tokens. ~~Doc~~ |
+
+## doc_cleaner {#doc_cleaner tag="function" new="3.2.1"}
+
+Clean up `Doc` attributes. Intended for use at the end of pipelines with
+`tok2vec` or `transformer` pipeline components that store tensors and other
+values that can require a lot of memory and frequently aren't needed after the
+whole pipeline has run.
+
+> #### Example
+>
+> ```python
+> config = {"attrs": {"tensor": None}}
+> nlp.add_pipe("doc_cleaner", config=config)
+> doc = nlp("text")
+> assert doc.tensor is None
+> ```
+
+| Setting     | Description |
+| ----------- | ----------- |
+| `attrs`     | A dict of the `Doc` attributes and the values to set them to. Defaults to `{"tensor": None, "_.trf_data": None}` to clean up after `tok2vec` and `transformer` components. ~~dict~~ |
+| `silent`    | If `False`, show warnings if attributes aren't found or can't be set. Defaults to `True`. ~~bool~~ |
+| **RETURNS** | The modified `Doc` with the modified attributes. ~~Doc~~ |
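For completeness, a hedged sketch of enabling the component from a training config rather than via `add_pipe` (standard `config.cfg` syntax; this block is not part of the diff):

```ini
[components.doc_cleaner]
factory = "doc_cleaner"
silent = true
```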