diff --git a/spacy/errors.py b/spacy/errors.py index 418d682ad..4ae43a497 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -610,6 +610,9 @@ class Errors: "initializing the pipeline:\n" 'cfg = {"tokenizer": {"segmenter": "pkuseg", "pkuseg_model": name_or_path}}\n' 'nlp = Chinese(config=cfg)') + E1001 = ("Target token outside of matched span for match with tokens " + "'{span}' and offset '{index}' matched by patterns '{patterns}'.") + E1002 = ("Span index out of range.") @add_codes diff --git a/spacy/pipeline/__init__.py b/spacy/pipeline/__init__.py index f8accd14f..7f395b5f2 100644 --- a/spacy/pipeline/__init__.py +++ b/spacy/pipeline/__init__.py @@ -1,3 +1,4 @@ +from .attributeruler import AttributeRuler from .dep_parser import DependencyParser from .entity_linker import EntityLinker from .ner import EntityRecognizer @@ -13,6 +14,7 @@ from .tok2vec import Tok2Vec from .functions import merge_entities, merge_noun_chunks, merge_subtokens __all__ = [ + "AttributeRuler", "DependencyParser", "EntityLinker", "EntityRecognizer", diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py new file mode 100644 index 000000000..ac86f60e0 --- /dev/null +++ b/spacy/pipeline/attributeruler.py @@ -0,0 +1,268 @@ +import srsly +from typing import List, Dict, Union, Iterable, Any, Optional +from pathlib import Path + +from .pipe import Pipe +from ..errors import Errors +from ..language import Language +from ..matcher import Matcher +from ..symbols import IDS +from ..tokens import Doc, Span +from ..tokens._retokenize import normalize_token_attrs, set_token_attrs +from ..vocab import Vocab +from .. import util + + +MatcherPatternType = List[Dict[Union[int, str], Any]] +AttributeRulerPatternType = Dict[str, Union[MatcherPatternType, Dict, int]] + + +@Language.factory( + "attribute_ruler", +) +def make_attribute_ruler( + nlp: Language, + name: str, + pattern_dicts: Optional[Iterable[AttributeRulerPatternType]] = None, +): + return AttributeRuler(nlp.vocab, name, pattern_dicts=pattern_dicts) + + +class AttributeRuler(Pipe): + """Set token-level attributes for tokens matched by Matcher patterns. + Additionally supports importing patterns from tag maps and morph rules. + + DOCS: https://spacy.io/api/attributeruler + """ + + def __init__( + self, + vocab: Vocab, + name: str = "attribute_ruler", + *, + pattern_dicts: Optional[Iterable[AttributeRulerPatternType]] = None, + ) -> None: + """Initialize the AttributeRuler. + + vocab (Vocab): The vocab. + name (str): The pipe name. Defaults to "attribute_ruler". + pattern_dicts (Iterable[Dict]): A list of pattern dicts with the keys as + the arguments to AttributeRuler.add (`patterns`/`attrs`/`index`) to add + as patterns. + + RETURNS (AttributeRuler): The AttributeRuler component. + + DOCS: https://spacy.io/api/attributeruler#init + """ + self.name = name + self.vocab = vocab + self.matcher = Matcher(self.vocab) + self.attrs = [] + self._attrs_unnormed = [] # store for reference + self.indices = [] + + if pattern_dicts: + self.add_patterns(pattern_dicts) + + def __call__(self, doc: Doc) -> Doc: + """Apply the attributeruler to a Doc and set all attribute exceptions. + + doc (Doc): The document to process. + RETURNS (Doc): The processed Doc. 
+ + DOCS: https://spacy.io/api/attributeruler#call + """ + matches = self.matcher(doc) + + for match_id, start, end in matches: + span = Span(doc, start, end, label=match_id) + attrs = self.attrs[span.label] + index = self.indices[span.label] + try: + token = span[index] + except IndexError: + raise ValueError( + Errors.E1001.format( + patterns=self.matcher.get(span.label), + span=[t.text for t in span], + index=index, + ) + ) + set_token_attrs(token, attrs) + return doc + + def load_from_tag_map( + self, tag_map: Dict[str, Dict[Union[int, str], Union[int, str]]] + ) -> None: + for tag, attrs in tag_map.items(): + pattern = [{"TAG": tag}] + attrs, morph_attrs = _split_morph_attrs(attrs) + morph = self.vocab.morphology.add(morph_attrs) + attrs["MORPH"] = self.vocab.strings[morph] + self.add([pattern], attrs) + + def load_from_morph_rules( + self, morph_rules: Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]] + ) -> None: + for tag in morph_rules: + for word in morph_rules[tag]: + pattern = [{"ORTH": word, "TAG": tag}] + attrs = morph_rules[tag][word] + attrs, morph_attrs = _split_morph_attrs(attrs) + morph = self.vocab.morphology.add(morph_attrs) + attrs["MORPH"] = self.vocab.strings[morph] + self.add([pattern], attrs) + + def add( + self, patterns: Iterable[MatcherPatternType], attrs: Dict, index: int = 0 + ) -> None: + """Add Matcher patterns for tokens that should be modified with the + provided attributes. The token at the specified index within the + matched span will be assigned the attributes. + + patterns (Iterable[List[Dict]]): A list of Matcher patterns. + attrs (Dict): The attributes to assign to the target token in the + matched span. + index (int): The index of the token in the matched span to modify. May + be negative to index from the end of the span. Defaults to 0. + + DOCS: https://spacy.io/api/attributeruler#add + """ + self.matcher.add(len(self.attrs), patterns) + self._attrs_unnormed.append(attrs) + attrs = normalize_token_attrs(self.vocab, attrs) + self.attrs.append(attrs) + self.indices.append(index) + + def add_patterns(self, pattern_dicts: Iterable[AttributeRulerPatternType]) -> None: + for p in pattern_dicts: + self.add(**p) + + @property + def patterns(self) -> List[AttributeRulerPatternType]: + all_patterns = [] + for i in range(len(self.attrs)): + p = {} + p["patterns"] = self.matcher.get(i)[1] + p["attrs"] = self._attrs_unnormed[i] + p["index"] = self.indices[i] + all_patterns.append(p) + return all_patterns + + def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes: + """Serialize the attributeruler to a bytestring. + + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (bytes): The serialized object. + + DOCS: https://spacy.io/api/attributeruler#to_bytes + """ + serialize = {} + serialize["vocab"] = self.vocab.to_bytes + patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))} + serialize["patterns"] = lambda: srsly.msgpack_dumps(patterns) + serialize["attrs"] = lambda: srsly.msgpack_dumps(self.attrs) + serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices) + return util.to_bytes(serialize, exclude) + + def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()): + """Load the attributeruler from a bytestring. + + bytes_data (bytes): The data to load. + exclude (Iterable[str]): String names of serialization fields to exclude. + returns (AttributeRuler): The loaded object. 
+ + DOCS: https://spacy.io/api/attributeruler#from_bytes + """ + data = {"patterns": b""} + + def load_patterns(b): + data["patterns"] = srsly.msgpack_loads(b) + + def load_attrs(b): + self.attrs = srsly.msgpack_loads(b) + + def load_indices(b): + self.indices = srsly.msgpack_loads(b) + + deserialize = { + "vocab": lambda b: self.vocab.from_bytes(b), + "patterns": load_patterns, + "attrs": load_attrs, + "indices": load_indices, + } + util.from_bytes(bytes_data, deserialize, exclude) + + if data["patterns"]: + for key, pattern in data["patterns"].items(): + self.matcher.add(key, pattern) + assert len(self.attrs) == len(data["patterns"]) + assert len(self.indices) == len(data["patterns"]) + + return self + + def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None: + """Serialize the attributeruler to disk. + + path (Union[Path, str]): A path to a directory. + exclude (Iterable[str]): String names of serialization fields to exclude. + DOCS: https://spacy.io/api/attributeruler#to_disk + """ + patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))} + serialize = { + "vocab": lambda p: self.vocab.to_disk(p), + "patterns": lambda p: srsly.write_msgpack(p, patterns), + "attrs": lambda p: srsly.write_msgpack(p, self.attrs), + "indices": lambda p: srsly.write_msgpack(p, self.indices), + } + util.to_disk(path, serialize, exclude) + + def from_disk( + self, path: Union[Path, str], exclude: Iterable[str] = tuple() + ) -> None: + """Load the attributeruler from disk. + + path (Union[Path, str]): A path to a directory. + exclude (Iterable[str]): String names of serialization fields to exclude. + DOCS: https://spacy.io/api/attributeruler#from_disk + """ + data = {"patterns": b""} + + def load_patterns(p): + data["patterns"] = srsly.read_msgpack(p) + + def load_attrs(p): + self.attrs = srsly.read_msgpack(p) + + def load_indices(p): + self.indices = srsly.read_msgpack(p) + + deserialize = { + "vocab": lambda p: self.vocab.from_disk(p), + "patterns": load_patterns, + "attrs": load_attrs, + "indices": load_indices, + } + util.from_disk(path, deserialize, exclude) + + if data["patterns"]: + for key, pattern in data["patterns"].items(): + self.matcher.add(key, pattern) + assert len(self.attrs) == len(data["patterns"]) + assert len(self.indices) == len(data["patterns"]) + + return self + + +def _split_morph_attrs(attrs): + """Split entries from a tag map or morph rules dict into to two dicts, one + with the token-level features (POS, LEMMA) and one with the remaining + features, which are presumed to be individual MORPH features.""" + other_attrs = {} + morph_attrs = {} + for k, v in attrs.items(): + if k in "_" or k in IDS.keys() or k in IDS.values(): + other_attrs[k] = v + else: + morph_attrs[k] = v + return other_attrs, morph_attrs diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 91b0ec922..686678a14 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -282,3 +282,15 @@ def test_span_eq_hash(doc, doc_not_parsed): assert hash(doc[0:2]) == hash(doc[0:2]) assert hash(doc[0:2]) != hash(doc[1:3]) assert hash(doc[0:2]) != hash(doc_not_parsed[0:2]) + + +def test_span_boundaries(doc): + start = 1 + end = 5 + span = doc[start:end] + for i in range(start, end): + assert span[i - start] == doc[i] + with pytest.raises(IndexError): + _ = span[-5] + with pytest.raises(IndexError): + _ = span[5] diff --git a/spacy/tests/pipeline/test_attributeruler.py b/spacy/tests/pipeline/test_attributeruler.py new file mode 100644 index 
000000000..a4cf34717 --- /dev/null +++ b/spacy/tests/pipeline/test_attributeruler.py @@ -0,0 +1,208 @@ +import pytest +import numpy +from spacy.lang.en import English +from spacy.pipeline import AttributeRuler +from spacy import util, registry + +from ..util import get_doc, make_tempdir + + +@pytest.fixture +def nlp(): + return English() + + +@pytest.fixture +def pattern_dicts(): + return [ + { + "patterns": [[{"ORTH": "a"}], [{"ORTH": "irrelevant"}]], + "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + }, + # one pattern sets the lemma + {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}}, + # another pattern sets the morphology + { + "patterns": [[{"ORTH": "test"}]], + "attrs": {"MORPH": "Case=Nom|Number=Sing"}, + "index": 0, + }, + ] + + +@registry.assets("attribute_ruler_patterns") +def attribute_ruler_patterns(): + return [ + { + "patterns": [[{"ORTH": "a"}], [{"ORTH": "irrelevant"}]], + "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + }, + # one pattern sets the lemma + {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}}, + # another pattern sets the morphology + { + "patterns": [[{"ORTH": "test"}]], + "attrs": {"MORPH": "Case=Nom|Number=Sing"}, + "index": 0, + }, + ] + + +@pytest.fixture +def tag_map(): + return { + ".": {"POS": "PUNCT", "PunctType": "peri"}, + ",": {"POS": "PUNCT", "PunctType": "comm"}, + } + + +@pytest.fixture +def morph_rules(): + return {"DT": {"the": {"POS": "DET", "LEMMA": "a", "Case": "Nom"}}} + + +def test_attributeruler_init(nlp, pattern_dicts): + a = nlp.add_pipe("attribute_ruler") + for p in pattern_dicts: + a.add(**p) + + doc = nlp("This is a test.") + assert doc[2].lemma_ == "the" + assert doc[2].morph_ == "Case=Nom|Number=Plur" + assert doc[3].lemma_ == "cat" + assert doc[3].morph_ == "Case=Nom|Number=Sing" + + +def test_attributeruler_init_patterns(nlp, pattern_dicts): + # initialize with patterns + a = nlp.add_pipe("attribute_ruler", config={"pattern_dicts": pattern_dicts}) + + doc = nlp("This is a test.") + assert doc[2].lemma_ == "the" + assert doc[2].morph_ == "Case=Nom|Number=Plur" + assert doc[3].lemma_ == "cat" + assert doc[3].morph_ == "Case=Nom|Number=Sing" + + nlp.remove_pipe("attribute_ruler") + + # initialize with patterns from asset + a = nlp.add_pipe("attribute_ruler", config={"pattern_dicts": {"@assets": "attribute_ruler_patterns"}}) + + doc = nlp("This is a test.") + assert doc[2].lemma_ == "the" + assert doc[2].morph_ == "Case=Nom|Number=Plur" + assert doc[3].lemma_ == "cat" + assert doc[3].morph_ == "Case=Nom|Number=Sing" + + +def test_attributeruler_tag_map(nlp, tag_map): + a = AttributeRuler(nlp.vocab) + a.load_from_tag_map(tag_map) + doc = get_doc( + nlp.vocab, + words=["This", "is", "a", "test", "."], + tags=["DT", "VBZ", "DT", "NN", "."], + ) + doc = a(doc) + + for i in range(len(doc)): + if i == 4: + assert doc[i].pos_ == "PUNCT" + assert doc[i].morph_ == "PunctType=peri" + else: + assert doc[i].pos_ == "" + assert doc[i].morph_ == "" + + +def test_attributeruler_morph_rules(nlp, morph_rules): + a = AttributeRuler(nlp.vocab) + a.load_from_morph_rules(morph_rules) + doc = get_doc( + nlp.vocab, + words=["This", "is", "the", "test", "."], + tags=["DT", "VBZ", "DT", "NN", "."], + ) + doc = a(doc) + + for i in range(len(doc)): + if i != 2: + assert doc[i].pos_ == "" + assert doc[i].morph_ == "" + else: + assert doc[2].pos_ == "DET" + assert doc[2].lemma_ == "a" + assert doc[2].morph_ == "Case=Nom" + + +def test_attributeruler_indices(nlp): + a = nlp.add_pipe("attribute_ruler") + a.add( + 
[[{"ORTH": "a"}, {"ORTH": "test"}]], + {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + index=0, + ) + a.add( + [[{"ORTH": "This"}, {"ORTH": "is"}]], + {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, + index=1, + ) + a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=-1) + + text = "This is a test." + doc = nlp(text) + + for i in range(len(doc)): + if i == 1: + assert doc[i].lemma_ == "was" + assert doc[i].morph_ == "Case=Nom|Number=Sing" + elif i == 2: + assert doc[i].lemma_ == "the" + assert doc[i].morph_ == "Case=Nom|Number=Plur" + elif i == 3: + assert doc[i].lemma_ == "cat" + else: + assert doc[i].morph_ == "" + + # raises an error when trying to modify a token outside of the match + a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=2) + with pytest.raises(ValueError): + doc = nlp(text) + + # raises an error when trying to modify a token outside of the match + a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=10) + with pytest.raises(ValueError): + doc = nlp(text) + + +def test_attributeruler_patterns_prop(nlp, pattern_dicts): + a = nlp.add_pipe("attribute_ruler") + a.add_patterns(pattern_dicts) + + for p1, p2 in zip(pattern_dicts, a.patterns): + assert p1["patterns"] == p2["patterns"] + assert p1["attrs"] == p2["attrs"] + if p1.get("index"): + assert p1["index"] == p2["index"] + + +def test_attributeruler_serialize(nlp, pattern_dicts): + a = nlp.add_pipe("attribute_ruler") + a.add_patterns(pattern_dicts) + + text = "This is a test." + attrs = ["ORTH", "LEMMA", "MORPH"] + doc = nlp(text) + + # bytes roundtrip + a_reloaded = AttributeRuler(nlp.vocab).from_bytes(a.to_bytes()) + assert a.to_bytes() == a_reloaded.to_bytes() + doc1 = a_reloaded(nlp.make_doc(text)) + numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs)) + + # disk roundtrip + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(text) + assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes() + assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs)) diff --git a/spacy/tokens/_retokenize.pyx b/spacy/tokens/_retokenize.pyx index b89ce3bdd..61f7c3db0 100644 --- a/spacy/tokens/_retokenize.pyx +++ b/spacy/tokens/_retokenize.pyx @@ -12,6 +12,7 @@ from .token cimport Token from ..lexeme cimport Lexeme, EMPTY_LEXEME from ..structs cimport LexemeC, TokenC from ..attrs cimport TAG, MORPH +from ..vocab cimport Vocab from .underscore import is_writable_attr from ..attrs import intify_attrs @@ -57,16 +58,7 @@ cdef class Retokenizer: raise ValueError(Errors.E102.format(token=repr(token))) self.tokens_to_merge.add(token.i) self._spans_to_merge.append((span.start, span.end)) - if "_" in attrs: # Extension attributes - extensions = attrs["_"] - _validate_extensions(extensions) - attrs = {key: value for key, value in attrs.items() if key != "_"} - attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings) - attrs["_"] = extensions - else: - attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings) - if MORPH in attrs: - self.doc.vocab.morphology.add(self.doc.vocab.strings.as_string(attrs[MORPH])) + attrs = normalize_token_attrs(self.doc.vocab, attrs) self.merges.append((span, attrs)) def split(self, Token token, orths, heads, attrs=SimpleFrozenDict()): @@ -98,9 +90,11 @@ cdef class Retokenizer: # NB: Since we support {"KEY": [value, value]} syntax here, this # will only "intify" the keys, not the values attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings) - if MORPH in attrs: - 
for morph in attrs[MORPH]: - self.doc.vocab.morphology.add(self.doc.vocab.strings.as_string(morph)) + if MORPH in attrs: + for i, morph in enumerate(attrs[MORPH]): + # add and set to normalized value + morph = self.doc.vocab.morphology.add(self.doc.vocab.strings.as_string(morph)) + attrs[MORPH][i] = morph head_offsets = [] for head in heads: if isinstance(head, Token): @@ -224,21 +218,7 @@ def _merge(Doc doc, merges): token.lex = lex # We set trailing space here too token.spacy = doc.c[spans[token_index].end-1].spacy - py_token = span[0] - # Assign attributes - for attr_name, attr_value in attributes.items(): - if attr_name == "_": # Set extension attributes - for ext_attr_key, ext_attr_value in attr_value.items(): - py_token._.set(ext_attr_key, ext_attr_value) - elif attr_name == TAG: - doc.vocab.morphology.assign_tag(token, attr_value) - else: - # Set attributes on both token and lexeme to take care of token - # attribute vs. lexical attribute without having to enumerate - # them. If an attribute name is not valid, set_struct_attr will - # ignore it. - Token.set_struct_attr(token, attr_name, attr_value) - Lexeme.set_struct_attr(lex, attr_name, attr_value) + set_token_attrs(span[0], attributes) # Begin by setting all the head indices to absolute token positions # This is easier to work with for now than the offsets # Before thinking of something simpler, beware the case where a @@ -423,3 +403,40 @@ cdef make_iob_consistent(TokenC* tokens, int length): for i in range(1, length): if tokens[i].ent_iob == 1 and tokens[i - 1].ent_type != tokens[i].ent_type: tokens[i].ent_iob = 3 + + +def normalize_token_attrs(Vocab vocab, attrs): + if "_" in attrs: # Extension attributes + extensions = attrs["_"] + print("EXTENSIONS", extensions) + _validate_extensions(extensions) + attrs = {key: value for key, value in attrs.items() if key != "_"} + attrs = intify_attrs(attrs, strings_map=vocab.strings) + attrs["_"] = extensions + else: + attrs = intify_attrs(attrs, strings_map=vocab.strings) + if MORPH in attrs: + # add and set to normalized value + morph = vocab.morphology.add(vocab.strings.as_string(attrs[MORPH])) + attrs[MORPH] = morph + return attrs + + +def set_token_attrs(Token py_token, attrs): + cdef TokenC* token = py_token.c + cdef const LexemeC* lex = token.lex + cdef Doc doc = py_token.doc + # Assign attributes + for attr_name, attr_value in attrs.items(): + if attr_name == "_": # Set extension attributes + for ext_attr_key, ext_attr_value in attr_value.items(): + py_token._.set(ext_attr_key, ext_attr_value) + elif attr_name == TAG: + doc.vocab.morphology.assign_tag(token, attr_value) + else: + # Set attributes on both token and lexeme to take care of token + # attribute vs. lexical attribute without having to enumerate + # them. If an attribute name is not valid, set_struct_attr will + # ignore it. + Token.set_struct_attr(token, attr_name, attr_value) + Lexeme.set_struct_attr(lex, attr_name, attr_value) diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 5b55d8e88..15e6518d6 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -176,9 +176,13 @@ cdef class Span: return Span(self.doc, start + self.start, end + self.start) else: if i < 0: - return self.doc[self.end + i] + token_i = self.end + i else: - return self.doc[self.start + i] + token_i = self.start + i + if self.start <= token_i < self.end: + return self.doc[token_i] + else: + raise IndexError(Errors.E1002) def __iter__(self): """Iterate over `Token` objects.
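
The tests added in this diff drive the new component end to end; the following is a minimal usage sketch assembled from those tests (the pattern contents are illustrative, the calls mirror the API introduced here):

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("attribute_ruler")
# Each pattern dict bundles Matcher patterns, the attrs to assign, and an
# optional index of the target token within the matched span.
ruler.add_patterns([
    {"patterns": [[{"ORTH": "a"}]], "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}},
    {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}, "index": 0},
])
doc = nlp("This is a test.")
assert doc[2].lemma_ == "the"
assert doc[2].morph_ == "Case=Nom|Number=Plur"
assert doc[3].lemma_ == "cat"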
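
AttributeRuler.add takes the match index separately: the token at that position within the matched span receives the attributes, negative values count from the end of the span, and an index that falls outside the span raises E1001 when the document is processed. Continuing the sketch above, mirroring test_attributeruler_indices:

# index selects the target token inside the matched span
ruler.add([[{"ORTH": "This"}, {"ORTH": "is"}]], {"LEMMA": "was"}, index=1)
ruler.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=-1)
doc = nlp("This is a test.")
assert doc[1].lemma_ == "was"
assert doc[3].lemma_ == "cat"

# An index outside the two-token match is only detected at call time and
# surfaces as a ValueError carrying E1001.
ruler.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=2)
try:
    nlp("This is a test.")
except ValueError as err:
    print(err)  # E1001: target token outside of matched span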
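
load_from_tag_map and load_from_morph_rules convert the v2-style tables into equivalent TAG- and ORTH-based patterns, splitting each entry into token-level attrs (POS, LEMMA) and the remaining MORPH features. A sketch under the assumption that tags are already set on the doc (the tag values are illustrative):

from spacy.pipeline import AttributeRuler

tag_ruler = AttributeRuler(nlp.vocab)
tag_ruler.load_from_tag_map({".": {"POS": "PUNCT", "PunctType": "peri"}})

doc = nlp.make_doc("This is a test.")
doc[4].tag_ = "."                          # the generated patterns match on TAG
doc = tag_ruler(doc)
assert doc[4].pos_ == "PUNCT"              # POS stays a token-level attribute
assert doc[4].morph_ == "PunctType=peri"   # the rest becomes MORPH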
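
Serialization stores the Matcher patterns together with the normalized attrs and indices, so a byte round-trip reconstructs an equivalent component; on disk the component is written as part of the pipeline via nlp.to_disk. A sketch along the lines of test_attributeruler_serialize:

from spacy.pipeline import AttributeRuler

ruler2 = AttributeRuler(nlp.vocab).from_bytes(ruler.to_bytes())
assert ruler2.to_bytes() == ruler.to_bytes()
# A disk round-trip goes through the usual pipeline serialization, i.e.
# nlp.to_disk(path) followed by loading the pipeline back from that path.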
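
The _retokenize.pyx refactor also changes how MORPH values passed to the retokenizer are stored: normalize_token_attrs adds the string to the morphology table and keeps the normalized key, so the features come back in canonical order. A sketch of that behaviour, assuming the alphabetical FEATS ordering applied by Morphology.add:

doc = nlp.make_doc("This is a test.")
with doc.retokenize() as retokenizer:
    retokenizer.merge(doc[2:4], attrs={"MORPH": "Number=Plur|Case=Nom"})
# The merged token carries the normalized form of the MORPH string.
assert doc[2].text == "a test"
assert doc[2].morph_ == "Case=Nom|Number=Plur"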
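
The span.pyx change tightens Span.__getitem__ for integer indices: an index that would land outside the span now raises E1002 instead of quietly returning a neighbouring token of the doc. A sketch of the new behaviour, following test_span_boundaries:

doc = nlp.make_doc("This is a longer test sentence.")
span = doc[1:5]                # "is a longer test"
assert span[0] == doc[1]
assert span[-1] == doc[4]
try:
    span[5]                    # beyond the 4-token span
except IndexError:
    pass                       # E1002: "Span index out of range."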