diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index edf6b722b..4f89e4186 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -26,7 +26,7 @@ class EntityRuler(object): name = "entity_ruler" - def __init__(self, nlp, **cfg): + def __init__(self, nlp, phrase_matcher_attr=None, **cfg): """Initialize the entitiy ruler. If patterns are supplied here, they need to be a list of dictionaries with a `"label"` and `"pattern"` key. A pattern can either be a token pattern (list) or a phrase pattern @@ -34,6 +34,8 @@ class EntityRuler(object): nlp (Language): The shared nlp object to pass the vocab to the matchers and process phrase patterns. + phrase_matcher_attr (int / unicode): Token attribute to match on, passed + to the internal PhraseMatcher as `attr` patterns (iterable): Optional patterns to load in. overwrite_ents (bool): If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. @@ -49,7 +51,12 @@ class EntityRuler(object): self.token_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list) self.matcher = Matcher(nlp.vocab) - self.phrase_matcher = PhraseMatcher(nlp.vocab) + if phrase_matcher_attr is not None: + self.phrase_matcher_attr = phrase_matcher_attr + self.phrase_matcher = PhraseMatcher(nlp.vocab, attr=self.phrase_matcher_attr) + else: + self.phrase_matcher_attr = None + self.phrase_matcher = PhraseMatcher(nlp.vocab) self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) patterns = cfg.get("patterns") if patterns is not None: @@ -218,6 +225,10 @@ class EntityRuler(object): if isinstance(cfg, dict): self.add_patterns(cfg.get('patterns', cfg)) self.overwrite = cfg.get('overwrite', False) + self.phrase_matcher_attr = cfg.get('phrase_matcher_attr', None) + if self.phrase_matcher_attr is not None: + self.phrase_matcher = PhraseMatcher(self.nlp.vocab, + attr=self.phrase_matcher_attr) self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) 
else: self.add_patterns(cfg) @@ -234,6 +245,7 @@ class EntityRuler(object): serial = OrderedDict(( ('overwrite', self.overwrite), ('ent_id_sep', self.ent_id_sep), + ('phrase_matcher_attr', self.phrase_matcher_attr), ('patterns', self.patterns))) return srsly.msgpack_dumps(serial) @@ -259,7 +271,12 @@ class EntityRuler(object): } from_disk(path, deserializers, {}) self.overwrite = cfg.get('overwrite', False) + self.phrase_matcher_attr = cfg.get('phrase_matcher_attr') self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) + + if self.phrase_matcher_attr is not None: + self.phrase_matcher = PhraseMatcher(self.nlp.vocab, + attr=self.phrase_matcher_attr) return self def to_disk(self, path, **kwargs): @@ -273,6 +290,7 @@ class EntityRuler(object): DOCS: https://spacy.io/api/entityruler#to_disk """ cfg = {'overwrite': self.overwrite, + 'phrase_matcher_attr': self.phrase_matcher_attr, 'ent_id_sep': self.ent_id_sep} serializers = { 'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'), diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index 040d5ff22..a371be38b 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -106,5 +106,24 @@ def test_entity_ruler_serialize_bytes(nlp, patterns): assert len(new_ruler) == 0 assert len(new_ruler.labels) == 0 new_ruler = new_ruler.from_bytes(ruler_bytes) + assert len(new_ruler) == len(patterns) + assert len(new_ruler.labels) == 4 + assert len(new_ruler.patterns) == len(ruler.patterns) + for pattern in ruler.patterns: + assert pattern in new_ruler.patterns + assert new_ruler.labels == ruler.labels + + +def test_entity_ruler_serialize_phrase_matcher_attr_bytes(nlp, patterns): + ruler = EntityRuler(nlp, phrase_matcher_attr="LOWER", patterns=patterns) assert len(ruler) == len(patterns) assert len(ruler.labels) == 4 + ruler_bytes = ruler.to_bytes() + new_ruler = EntityRuler(nlp) + assert len(new_ruler) == 0 + assert 
len(new_ruler.labels) == 0 + assert new_ruler.phrase_matcher_attr is None + new_ruler = new_ruler.from_bytes(ruler_bytes) + assert len(new_ruler) == len(patterns) + assert len(new_ruler.labels) == 4 + assert new_ruler.phrase_matcher_attr == "LOWER" diff --git a/spacy/tests/regression/test_issue3526.py b/spacy/tests/regression/test_issue3526.py index 118cb3af5..3949c4b1c 100644 --- a/spacy/tests/regression/test_issue3526.py +++ b/spacy/tests/regression/test_issue3526.py @@ -9,6 +9,7 @@ from spacy import load import srsly from ..util import make_tempdir + @pytest.fixture def patterns(): return [ @@ -28,6 +29,7 @@ def add_ent(): return add_ent_component + def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab): nlp = Language(vocab=en_vocab) ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) @@ -50,7 +52,8 @@ def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab): new_ruler = EntityRuler(nlp) new_ruler = new_ruler.from_bytes(bytes_old_style) assert len(new_ruler) == len(ruler) - assert new_ruler.patterns == ruler.patterns + for pattern in ruler.patterns: + assert pattern in new_ruler.patterns assert new_ruler.overwrite is not ruler.overwrite @@ -62,7 +65,8 @@ def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab): srsly.write_jsonl(out_file, ruler.patterns) new_ruler = EntityRuler(nlp) new_ruler = new_ruler.from_disk(out_file) - assert new_ruler.patterns == ruler.patterns + for pattern in ruler.patterns: + assert pattern in new_ruler.patterns assert len(new_ruler) == len(ruler) assert new_ruler.overwrite is not ruler.overwrite diff --git a/website/docs/api/entityruler.md b/website/docs/api/entityruler.md index 45c4756f2..dcbf99da5 100644 --- a/website/docs/api/entityruler.md +++ b/website/docs/api/entityruler.md @@ -34,6 +34,7 @@ be a token pattern (list) or a phrase pattern (string). 
For example: | ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- | | `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. | | `patterns` | iterable | Optional patterns to load in. | +| `phrase_matcher_attr` | int / unicode | Optional attr to pass to the internal [`PhraseMatcher`](/api/phrasematcher). Defaults to `None`. | | `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. | | `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. | | **RETURNS** | `EntityRuler` | The newly constructed object. |