mirror of https://github.com/explosion/spaCy.git
Added an argument to `EntityRuler` constructor to pass attrs to… (#3919)
* Perserve flags in EntityRuler The EntityRuler (explosion/spaCy#3526) does not preserve overwrite flags (or `ent_id_sep`) when serialized. This commit adds support for serialization/deserialization preserving overwrite and ent_id_sep flags. * add signed contributor agreement * flake8 cleanup mostly blank line issues. * mark test from the issue as needing a model The test from the issue needs some language model for serialization but the test wasn't originally marked correctly. * Adds `phrase_matcher_attr` to allow args to PhraseMatcher This is an added arg to pass to the `PhraseMatcher`. For example, this allows creation of a case insensitive phrase matcher when the `EntityRuler` is created. References explosion/spaCy#3822 * remove unneeded model loading The model didn't need to be loaded, and I replaced it with a change that doesn't require it (using existings fixtures) * updated docstring for new argument * updated docs to reflect new argument to the EntityRuler constructor * change tempdir handling to be compatible with python 2.7 * return conflicted code to entityruler Some stuff got cut out because of merge conflicts, this returns that code for the phrase_matcher_attr. * fixed typo in the code added back after conflicts * flake8 compliance When I deconflicted the branch there were some flake8 issues introduced. This resolves the spacing problems. * test changes: attempts to fix flaky test in python3.5 These tests seem to be alittle flaky in 3.5 so I changed the check to avoid the comparisons that seem to be fail sometimes.
This commit is contained in:
parent
a795fbd3b2
commit
2eb925bd05
|
@ -26,7 +26,7 @@ class EntityRuler(object):
|
|||
|
||||
name = "entity_ruler"
|
||||
|
||||
def __init__(self, nlp, **cfg):
|
||||
def __init__(self, nlp, phrase_matcher_attr=None, **cfg):
|
||||
"""Initialize the entitiy ruler. If patterns are supplied here, they
|
||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||
key. A pattern can either be a token pattern (list) or a phrase pattern
|
||||
|
@ -34,6 +34,8 @@ class EntityRuler(object):
|
|||
|
||||
nlp (Language): The shared nlp object to pass the vocab to the matchers
|
||||
and process phrase patterns.
|
||||
phrase_matcher_attr (int / unicode): Token attribute to match on, passed
|
||||
to the internal PhraseMatcher as `attr`
|
||||
patterns (iterable): Optional patterns to load in.
|
||||
overwrite_ents (bool): If existing entities are present, e.g. entities
|
||||
added by the model, overwrite them by matches if necessary.
|
||||
|
@ -49,7 +51,12 @@ class EntityRuler(object):
|
|||
self.token_patterns = defaultdict(list)
|
||||
self.phrase_patterns = defaultdict(list)
|
||||
self.matcher = Matcher(nlp.vocab)
|
||||
self.phrase_matcher = PhraseMatcher(nlp.vocab)
|
||||
if phrase_matcher_attr is not None:
|
||||
self.phrase_matcher_attr = phrase_matcher_attr
|
||||
self.phrase_matcher = PhraseMatcher(nlp.vocab, attr=self.phrase_matcher_attr)
|
||||
else:
|
||||
self.phrase_matcher_attr = None
|
||||
self.phrase_matcher = PhraseMatcher(nlp.vocab)
|
||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||
patterns = cfg.get("patterns")
|
||||
if patterns is not None:
|
||||
|
@ -218,6 +225,10 @@ class EntityRuler(object):
|
|||
if isinstance(cfg, dict):
|
||||
self.add_patterns(cfg.get('patterns', cfg))
|
||||
self.overwrite = cfg.get('overwrite', False)
|
||||
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr', None)
|
||||
if self.phrase_matcher_attr is not None:
|
||||
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
|
||||
attr=self.phrase_matcher_attr)
|
||||
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
||||
else:
|
||||
self.add_patterns(cfg)
|
||||
|
@ -234,6 +245,7 @@ class EntityRuler(object):
|
|||
serial = OrderedDict((
|
||||
('overwrite', self.overwrite),
|
||||
('ent_id_sep', self.ent_id_sep),
|
||||
('phrase_matcher_attr', self.phrase_matcher_attr),
|
||||
('patterns', self.patterns)))
|
||||
return srsly.msgpack_dumps(serial)
|
||||
|
||||
|
@ -259,7 +271,12 @@ class EntityRuler(object):
|
|||
}
|
||||
from_disk(path, deserializers, {})
|
||||
self.overwrite = cfg.get('overwrite', False)
|
||||
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr')
|
||||
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
||||
|
||||
if self.phrase_matcher_attr is not None:
|
||||
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
|
||||
attr=self.phrase_matcher_attr)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **kwargs):
|
||||
|
@ -273,6 +290,7 @@ class EntityRuler(object):
|
|||
DOCS: https://spacy.io/api/entityruler#to_disk
|
||||
"""
|
||||
cfg = {'overwrite': self.overwrite,
|
||||
'phrase_matcher_attr': self.phrase_matcher_attr,
|
||||
'ent_id_sep': self.ent_id_sep}
|
||||
serializers = {
|
||||
'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'),
|
||||
|
|
|
@ -106,5 +106,24 @@ def test_entity_ruler_serialize_bytes(nlp, patterns):
|
|||
assert len(new_ruler) == 0
|
||||
assert len(new_ruler.labels) == 0
|
||||
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
||||
assert len(new_ruler) == len(patterns)
|
||||
assert len(new_ruler.labels) == 4
|
||||
assert len(new_ruler.patterns) == len(ruler.patterns)
|
||||
for pattern in ruler.patterns:
|
||||
assert pattern in new_ruler.patterns
|
||||
assert new_ruler.labels == ruler.labels
|
||||
|
||||
|
||||
def test_entity_ruler_serialize_phrase_matcher_attr_bytes(nlp, patterns):
|
||||
ruler = EntityRuler(nlp, phrase_matcher_attr="LOWER", patterns=patterns)
|
||||
assert len(ruler) == len(patterns)
|
||||
assert len(ruler.labels) == 4
|
||||
ruler_bytes = ruler.to_bytes()
|
||||
new_ruler = EntityRuler(nlp)
|
||||
assert len(new_ruler) == 0
|
||||
assert len(new_ruler.labels) == 0
|
||||
assert new_ruler.phrase_matcher_attr is None
|
||||
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
||||
assert len(new_ruler) == len(patterns)
|
||||
assert len(new_ruler.labels) == 4
|
||||
assert new_ruler.phrase_matcher_attr == "LOWER"
|
||||
|
|
|
@ -9,6 +9,7 @@ from spacy import load
|
|||
import srsly
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patterns():
|
||||
return [
|
||||
|
@ -28,6 +29,7 @@ def add_ent():
|
|||
|
||||
return add_ent_component
|
||||
|
||||
|
||||
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
|
||||
nlp = Language(vocab=en_vocab)
|
||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||
|
@ -50,7 +52,8 @@ def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab):
|
|||
new_ruler = EntityRuler(nlp)
|
||||
new_ruler = new_ruler.from_bytes(bytes_old_style)
|
||||
assert len(new_ruler) == len(ruler)
|
||||
assert new_ruler.patterns == ruler.patterns
|
||||
for pattern in ruler.patterns:
|
||||
assert pattern in new_ruler.patterns
|
||||
assert new_ruler.overwrite is not ruler.overwrite
|
||||
|
||||
|
||||
|
@ -62,7 +65,8 @@ def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab):
|
|||
srsly.write_jsonl(out_file, ruler.patterns)
|
||||
new_ruler = EntityRuler(nlp)
|
||||
new_ruler = new_ruler.from_disk(out_file)
|
||||
assert new_ruler.patterns == ruler.patterns
|
||||
for pattern in ruler.patterns:
|
||||
assert pattern in new_ruler.patterns
|
||||
assert len(new_ruler) == len(ruler)
|
||||
assert new_ruler.overwrite is not ruler.overwrite
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ be a token pattern (list) or a phrase pattern (string). For example:
|
|||
| ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
|
||||
| `patterns` | iterable | Optional patterns to load in. |
|
||||
| `phrase_matcher_attr` | int / unicode | Optional attr to pass to the internal [`PhraseMatcher`](/api/phtasematcher). defaults to `None`
|
||||
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
|
||||
| `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. |
|
||||
| **RETURNS** | `EntityRuler` | The newly constructed object. |
|
||||
|
|
Loading…
Reference in New Issue