Fix attributeruler

This commit is contained in:
Matthew Honnibal 2020-09-26 00:19:53 +02:00
parent 98327f66a9
commit 821f37254c
1 changed files with 16 additions and 6 deletions

View File

@ -80,11 +80,14 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#call DOCS: https://nightly.spacy.io/api/attributeruler#call
""" """
matches = sorted(self.matcher(doc, allow_missing=True)) matches = sorted(self.matcher(doc, allow_missing=True))
print("Attrs", self.attrs)
print("Matches", matches)
for match_id, start, end in matches: for match_id, start, end in matches:
span = Span(doc, start, end, label=match_id) span = Span(doc, start, end, label=match_id)
attrs = self.attrs[span.label] attr_id = _parse_key(span.label_)
index = self.indices[span.label] attrs = self.attrs[attr_id]
index = self.indices[attr_id]
try: try:
token = span[index] token = span[index]
except IndexError: except IndexError:
@ -173,9 +176,10 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#add DOCS: https://nightly.spacy.io/api/attributeruler#add
""" """
# This needs to be a string, because otherwise it's interpreted as a # We need to make a string here, because otherwise the ID we pass back
# string key. # will be interpreted as the hash of a string, rather than an ordinal.
self.matcher.add(f"attr_rules_{len(self.attrs)}", patterns) key = _make_key(len(self.attrs))
self.matcher.add(self.vocab.strings.add(key), patterns)
self._attrs_unnormed.append(attrs) self._attrs_unnormed.append(attrs)
attrs = normalize_token_attrs(self.vocab, attrs) attrs = normalize_token_attrs(self.vocab, attrs)
self.attrs.append(attrs) self.attrs.append(attrs)
@ -199,7 +203,7 @@ class AttributeRuler(Pipe):
all_patterns = [] all_patterns = []
for i in range(len(self.attrs)): for i in range(len(self.attrs)):
p = {} p = {}
p["patterns"] = self.matcher.get(i)[1] p["patterns"] = self.matcher.get(_make_key(i))[1]
p["attrs"] = self._attrs_unnormed[i] p["attrs"] = self._attrs_unnormed[i]
p["index"] = self.indices[i] p["index"] = self.indices[i]
all_patterns.append(p) all_patterns.append(p)
@ -303,6 +307,12 @@ class AttributeRuler(Pipe):
return self return self
def _make_key(n_attr):
return f"attr_rule_{n_attr}"
def _parse_key(key):
return int(key.rsplit("_", 1)[1])
def _split_morph_attrs(attrs): def _split_morph_attrs(attrs):
"""Split entries from a tag map or morph rules dict into to two dicts, one """Split entries from a tag map or morph rules dict into to two dicts, one