From 85778dfcf411dcf2cef305feb661f6fc51abacde Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 28 Mar 2022 11:13:50 +0200 Subject: [PATCH] Add edit tree lemmatizer (#10231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add edit tree lemmatizer Co-authored-by: Daniël de Kok * Hide edit tree lemmatizer labels * Use relative imports * Switch to single quotes in error message * Type annotation fixes Co-authored-by: Sofie Van Landeghem * Reformat edit_tree_lemmatizer with black * EditTreeLemmatizer.predict: take Iterable Co-authored-by: Sofie Van Landeghem * Validate edit trees during deserialization This change also changes the serialized representation. Rather than mirroring the deep C structure, we use a simple flat union of the match and substitution node types. * Move edit_trees to _edit_tree_internals * Fix invalid edit tree format error message * edit_tree_lemmatizer: remove outdated TODO comment * Rename factory name to trainable_lemmatizer * Ignore type instead of casting truths to List[Union[Ints1d, Floats2d, List[int], List[str]]] for thinc v8.0.14 * Switch to Tagger.v2 * Add documentation for EditTreeLemmatizer * docs: Fix 3.2 -> 3.3 somewhere * trainable_lemmatizer documentation fixes * docs: EditTreeLemmatizer is in edit_tree_lemmatizer.py Co-authored-by: Daniël de Kok Co-authored-by: Daniël de Kok Co-authored-by: Sofie Van Landeghem --- setup.py | 1 + spacy/errors.py | 2 + spacy/pipeline/__init__.py | 1 + .../pipeline/_edit_tree_internals/__init__.py | 0 .../_edit_tree_internals/edit_trees.pxd | 93 ++++ .../_edit_tree_internals/edit_trees.pyx | 305 +++++++++++++ .../pipeline/_edit_tree_internals/schemas.py | 44 ++ spacy/pipeline/edit_tree_lemmatizer.py | 379 ++++++++++++++++ .../pipeline/test_edit_tree_lemmatizer.py | 280 ++++++++++++ website/docs/api/edittreelemmatizer.md | 409 ++++++++++++++++++ website/docs/api/lemmatizer.md | 7 +- website/docs/usage/101/_architecture.md | 3 +- website/docs/usage/linguistic-features.md | 31 +- website/docs/usage/processing-pipelines.md | 33 +- website/meta/sidebars.json | 1 + 15 files changed, 1562 insertions(+), 27 deletions(-) create mode 100644 spacy/pipeline/_edit_tree_internals/__init__.py create mode 100644 spacy/pipeline/_edit_tree_internals/edit_trees.pxd create mode 100644 spacy/pipeline/_edit_tree_internals/edit_trees.pyx create mode 100644 spacy/pipeline/_edit_tree_internals/schemas.py create mode 100644 spacy/pipeline/edit_tree_lemmatizer.py create mode 100644 spacy/tests/pipeline/test_edit_tree_lemmatizer.py create mode 100644 website/docs/api/edittreelemmatizer.md diff --git a/setup.py b/setup.py index fcc124a43..a5748e9b4 100755 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ MOD_NAMES = [ "spacy.ml.parser_model", "spacy.morphology", "spacy.pipeline.dep_parser", + "spacy.pipeline._edit_tree_internals.edit_trees", "spacy.pipeline.morphologizer", "spacy.pipeline.multitask", "spacy.pipeline.ner", diff --git a/spacy/errors.py b/spacy/errors.py index fe37351f7..8980ca3c3 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -524,6 +524,7 @@ class Errors(metaclass=ErrorsWithCodes): E202 = ("Unsupported {name} mode '{mode}'. Supported modes: {modes}.") # New errors added in v3.x + E857 = ("Entry '{name}' not found in edit tree lemmatizer labels.") E858 = ("The {mode} vector table does not support this operation. " "{alternative}") E859 = ("The floret vector table cannot be modified.") @@ -895,6 +896,7 @@ class Errors(metaclass=ErrorsWithCodes): "patterns.") E1025 = ("Cannot intify the value '{value}' as an IOB string. The only " "supported values are: 'I', 'O', 'B' and ''") + E1026 = ("Edit tree has an invalid format:\n{errors}") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/pipeline/__init__.py b/spacy/pipeline/__init__.py index 7b483724c..938ab08c6 100644 --- a/spacy/pipeline/__init__.py +++ b/spacy/pipeline/__init__.py @@ -1,5 +1,6 @@ from .attributeruler import AttributeRuler from .dep_parser import DependencyParser +from .edit_tree_lemmatizer import EditTreeLemmatizer from .entity_linker import EntityLinker from .ner import EntityRecognizer from .entityruler import EntityRuler diff --git a/spacy/pipeline/_edit_tree_internals/__init__.py b/spacy/pipeline/_edit_tree_internals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spacy/pipeline/_edit_tree_internals/edit_trees.pxd b/spacy/pipeline/_edit_tree_internals/edit_trees.pxd new file mode 100644 index 000000000..dc4289f37 --- /dev/null +++ b/spacy/pipeline/_edit_tree_internals/edit_trees.pxd @@ -0,0 +1,93 @@ +from libc.stdint cimport uint32_t, uint64_t +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector + +from ...typedefs cimport attr_t, hash_t, len_t +from ...strings cimport StringStore + +cdef extern from "" namespace "std" nogil: + void swap[T](T& a, T& b) except + # Only available in Cython 3. + +# An edit tree (Müller et al., 2015) is a tree structure that consists of +# edit operations. The two types of operations are string matches +# and string substitutions. Given an input string s and an output string t, +# subsitution and match nodes should be interpreted as follows: +# +# * Substitution node: consists of an original string and substitute string. +# If s matches the original string, then t is the substitute. Otherwise, +# the node does not apply. +# * Match node: consists of a prefix length, suffix length, prefix edit tree, +# and suffix edit tree. If s is composed of a prefix, middle part, and suffix +# with the given suffix and prefix lengths, then t is the concatenation +# prefix_tree(prefix) + middle + suffix_tree(suffix). +# +# For efficiency, we represent strings in substitution nodes as integers, with +# the actual strings stored in a StringStore. Subtrees in match nodes are stored +# as tree identifiers (rather than pointers) to simplify serialization. + +cdef uint32_t NULL_TREE_ID + +cdef struct MatchNodeC: + len_t prefix_len + len_t suffix_len + uint32_t prefix_tree + uint32_t suffix_tree + +cdef struct SubstNodeC: + attr_t orig + attr_t subst + +cdef union NodeC: + MatchNodeC match_node + SubstNodeC subst_node + +cdef struct EditTreeC: + bint is_match_node + NodeC inner + +cdef inline EditTreeC edittree_new_match(len_t prefix_len, len_t suffix_len, + uint32_t prefix_tree, uint32_t suffix_tree): + cdef MatchNodeC match_node = MatchNodeC(prefix_len=prefix_len, + suffix_len=suffix_len, prefix_tree=prefix_tree, + suffix_tree=suffix_tree) + cdef NodeC inner = NodeC(match_node=match_node) + return EditTreeC(is_match_node=True, inner=inner) + +cdef inline EditTreeC edittree_new_subst(attr_t orig, attr_t subst): + cdef EditTreeC node + cdef SubstNodeC subst_node = SubstNodeC(orig=orig, subst=subst) + cdef NodeC inner = NodeC(subst_node=subst_node) + return EditTreeC(is_match_node=False, inner=inner) + +cdef inline uint64_t edittree_hash(EditTreeC tree): + cdef MatchNodeC match_node + cdef SubstNodeC subst_node + + if tree.is_match_node: + match_node = tree.inner.match_node + return hash((match_node.prefix_len, match_node.suffix_len, match_node.prefix_tree, match_node.suffix_tree)) + else: + subst_node = tree.inner.subst_node + return hash((subst_node.orig, subst_node.subst)) + +cdef struct LCS: + int source_begin + int source_end + int target_begin + int target_end + +cdef inline bint lcs_is_empty(LCS lcs): + return lcs.source_begin == 0 and lcs.source_end == 0 and lcs.target_begin == 0 and lcs.target_end == 0 + +cdef class EditTrees: + cdef vector[EditTreeC] trees + cdef unordered_map[hash_t, uint32_t] map + cdef StringStore strings + + cpdef uint32_t add(self, str form, str lemma) + cpdef str apply(self, uint32_t tree_id, str form) + cpdef unicode tree_to_str(self, uint32_t tree_id) + + cdef uint32_t _add(self, str form, str lemma) + cdef _apply(self, uint32_t tree_id, str form_part, list lemma_pieces) + cdef uint32_t _tree_id(self, EditTreeC tree) diff --git a/spacy/pipeline/_edit_tree_internals/edit_trees.pyx b/spacy/pipeline/_edit_tree_internals/edit_trees.pyx new file mode 100644 index 000000000..02907b67a --- /dev/null +++ b/spacy/pipeline/_edit_tree_internals/edit_trees.pyx @@ -0,0 +1,305 @@ +# cython: infer_types=True, binding=True +from cython.operator cimport dereference as deref +from libc.stdint cimport uint32_t +from libc.stdint cimport UINT32_MAX +from libc.string cimport memset +from libcpp.pair cimport pair +from libcpp.vector cimport vector + +from pathlib import Path + +from ...typedefs cimport hash_t + +from ... import util +from ...errors import Errors +from ...strings import StringStore +from .schemas import validate_edit_tree + + +NULL_TREE_ID = UINT32_MAX + +cdef LCS find_lcs(str source, str target): + """ + Find the longest common subsequence (LCS) between two strings. If there are + multiple LCSes, only one of them is returned. + + source (str): The first string. + target (str): The second string. + RETURNS (LCS): The spans of the longest common subsequences. + """ + cdef Py_ssize_t source_len = len(source) + cdef Py_ssize_t target_len = len(target) + cdef size_t longest_align = 0; + cdef int source_idx, target_idx + cdef LCS lcs + cdef Py_UCS4 source_cp, target_cp + + memset(&lcs, 0, sizeof(lcs)) + + cdef vector[size_t] prev_aligns = vector[size_t](target_len); + cdef vector[size_t] cur_aligns = vector[size_t](target_len); + + for (source_idx, source_cp) in enumerate(source): + for (target_idx, target_cp) in enumerate(target): + if source_cp == target_cp: + if source_idx == 0 or target_idx == 0: + cur_aligns[target_idx] = 1 + else: + cur_aligns[target_idx] = prev_aligns[target_idx - 1] + 1 + + # Check if this is the longest alignment and replace previous + # best alignment when this is the case. + if cur_aligns[target_idx] > longest_align: + longest_align = cur_aligns[target_idx] + lcs.source_begin = source_idx - longest_align + 1 + lcs.source_end = source_idx + 1 + lcs.target_begin = target_idx - longest_align + 1 + lcs.target_end = target_idx + 1 + else: + # No match, we start with a zero-length alignment. + cur_aligns[target_idx] = 0 + swap(prev_aligns, cur_aligns) + + return lcs + +cdef class EditTrees: + """Container for constructing and storing edit trees.""" + def __init__(self, strings: StringStore): + """Create a container for edit trees. + + strings (StringStore): the string store to use.""" + self.strings = strings + + cpdef uint32_t add(self, str form, str lemma): + """Add an edit tree that rewrites the given string into the given lemma. + + RETURNS (int): identifier of the edit tree in the container. + """ + # Treat two empty strings as a special case. Generating an edit + # tree for identical strings results in a match node. However, + # since two empty strings have a zero-length LCS, a substitution + # node would be created. Since we do not want to clutter the + # recursive tree construction with logic for this case, handle + # it in this wrapper method. + if len(form) == 0 and len(lemma) == 0: + tree = edittree_new_match(0, 0, NULL_TREE_ID, NULL_TREE_ID) + return self._tree_id(tree) + + return self._add(form, lemma) + + cdef uint32_t _add(self, str form, str lemma): + cdef LCS lcs = find_lcs(form, lemma) + + cdef EditTreeC tree + cdef uint32_t tree_id, prefix_tree, suffix_tree + if lcs_is_empty(lcs): + tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma)) + else: + # If we have a non-empty LCS, such as "gooi" in "ge[gooi]d" and "[gooi]en", + # create edit trees for the prefix pair ("ge"/"") and the suffix pair ("d"/"en"). + prefix_tree = NULL_TREE_ID + if lcs.source_begin != 0 or lcs.target_begin != 0: + prefix_tree = self.add(form[:lcs.source_begin], lemma[:lcs.target_begin]) + + suffix_tree = NULL_TREE_ID + if lcs.source_end != len(form) or lcs.target_end != len(lemma): + suffix_tree = self.add(form[lcs.source_end:], lemma[lcs.target_end:]) + + tree = edittree_new_match(lcs.source_begin, len(form) - lcs.source_end, prefix_tree, suffix_tree) + + return self._tree_id(tree) + + cdef uint32_t _tree_id(self, EditTreeC tree): + # If this tree has been constructed before, return its identifier. + cdef hash_t hash = edittree_hash(tree) + cdef unordered_map[hash_t, uint32_t].iterator iter = self.map.find(hash) + if iter != self.map.end(): + return deref(iter).second + + # The tree hasn't been seen before, store it. + cdef uint32_t tree_id = self.trees.size() + self.trees.push_back(tree) + self.map.insert(pair[hash_t, uint32_t](hash, tree_id)) + + return tree_id + + cpdef str apply(self, uint32_t tree_id, str form): + """Apply an edit tree to a form. + + tree_id (uint32_t): the identifier of the edit tree to apply. + form (str): the form to apply the edit tree to. + RETURNS (str): the transformer form or None if the edit tree + could not be applied to the form. + """ + if tree_id >= self.trees.size(): + raise IndexError("Edit tree identifier out of range") + + lemma_pieces = [] + try: + self._apply(tree_id, form, lemma_pieces) + except ValueError: + return None + return "".join(lemma_pieces) + + cdef _apply(self, uint32_t tree_id, str form_part, list lemma_pieces): + """Recursively apply an edit tree to a form, adding pieces to + the lemma_pieces list.""" + assert tree_id <= self.trees.size() + + cdef EditTreeC tree = self.trees[tree_id] + cdef MatchNodeC match_node + cdef int suffix_start + + if tree.is_match_node: + match_node = tree.inner.match_node + + if match_node.prefix_len + match_node.suffix_len > len(form_part): + raise ValueError("Edit tree cannot be applied to form") + + suffix_start = len(form_part) - match_node.suffix_len + + if match_node.prefix_tree != NULL_TREE_ID: + self._apply(match_node.prefix_tree, form_part[:match_node.prefix_len], lemma_pieces) + + lemma_pieces.append(form_part[match_node.prefix_len:suffix_start]) + + if match_node.suffix_tree != NULL_TREE_ID: + self._apply(match_node.suffix_tree, form_part[suffix_start:], lemma_pieces) + else: + if form_part == self.strings[tree.inner.subst_node.orig]: + lemma_pieces.append(self.strings[tree.inner.subst_node.subst]) + else: + raise ValueError("Edit tree cannot be applied to form") + + cpdef unicode tree_to_str(self, uint32_t tree_id): + """Return the tree as a string. The tree tree string is formatted + like an S-expression. This is primarily useful for debugging. Match + nodes have the following format: + + (m prefix_len suffix_len prefix_tree suffix_tree) + + Substitution nodes have the following format: + + (s original substitute) + + tree_id (uint32_t): the identifier of the edit tree. + RETURNS (str): the tree as an S-expression. + """ + + if tree_id >= self.trees.size(): + raise IndexError("Edit tree identifier out of range") + + cdef EditTreeC tree = self.trees[tree_id] + cdef SubstNodeC subst_node + + if not tree.is_match_node: + subst_node = tree.inner.subst_node + return f"(s '{self.strings[subst_node.orig]}' '{self.strings[subst_node.subst]}')" + + cdef MatchNodeC match_node = tree.inner.match_node + + prefix_tree = "()" + if match_node.prefix_tree != NULL_TREE_ID: + prefix_tree = self.tree_to_str(match_node.prefix_tree) + + suffix_tree = "()" + if match_node.suffix_tree != NULL_TREE_ID: + suffix_tree = self.tree_to_str(match_node.suffix_tree) + + return f"(m {match_node.prefix_len} {match_node.suffix_len} {prefix_tree} {suffix_tree})" + + def from_json(self, trees: list) -> "EditTrees": + self.trees.clear() + + for tree in trees: + tree = _dict2tree(tree) + self.trees.push_back(tree) + + self._rebuild_tree_map() + + def from_bytes(self, bytes_data: bytes, *) -> "EditTrees": + def deserialize_trees(tree_dicts): + cdef EditTreeC c_tree + for tree_dict in tree_dicts: + c_tree = _dict2tree(tree_dict) + self.trees.push_back(c_tree) + + deserializers = {} + deserializers["trees"] = lambda n: deserialize_trees(n) + util.from_bytes(bytes_data, deserializers, []) + + self._rebuild_tree_map() + + return self + + def to_bytes(self, **kwargs) -> bytes: + tree_dicts = [] + for tree in self.trees: + tree = _tree2dict(tree) + tree_dicts.append(tree) + + serializers = {} + serializers["trees"] = lambda: tree_dicts + + return util.to_bytes(serializers, []) + + def to_disk(self, path, **kwargs) -> "EditTrees": + path = util.ensure_path(path) + with path.open("wb") as file_: + file_.write(self.to_bytes()) + + def from_disk(self, path, **kwargs) -> "EditTrees": + path = util.ensure_path(path) + if path.exists(): + with path.open("rb") as file_: + data = file_.read() + return self.from_bytes(data) + + return self + + def __getitem__(self, idx): + return _tree2dict(self.trees[idx]) + + def __len__(self): + return self.trees.size() + + def _rebuild_tree_map(self): + """Rebuild the tree hash -> tree id mapping""" + cdef EditTreeC c_tree + cdef uint32_t tree_id + cdef hash_t tree_hash + + self.map.clear() + + for tree_id in range(self.trees.size()): + c_tree = self.trees[tree_id] + tree_hash = edittree_hash(c_tree) + self.map.insert(pair[hash_t, uint32_t](tree_hash, tree_id)) + + def __reduce__(self): + return (unpickle_edittrees, (self.strings, self.to_bytes())) + + +def unpickle_edittrees(strings, trees_data): + return EditTrees(strings).from_bytes(trees_data) + + +def _tree2dict(tree): + if tree["is_match_node"]: + tree = tree["inner"]["match_node"] + else: + tree = tree["inner"]["subst_node"] + return(dict(tree)) + +def _dict2tree(tree): + errors = validate_edit_tree(tree) + if errors: + raise ValueError(Errors.E1026.format(errors="\n".join(errors))) + + tree = dict(tree) + if "prefix_len" in tree: + tree = {"is_match_node": True, "inner": {"match_node": tree}} + else: + tree = {"is_match_node": False, "inner": {"subst_node": tree}} + + return tree diff --git a/spacy/pipeline/_edit_tree_internals/schemas.py b/spacy/pipeline/_edit_tree_internals/schemas.py new file mode 100644 index 000000000..c01d0632e --- /dev/null +++ b/spacy/pipeline/_edit_tree_internals/schemas.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, List, Union +from collections import defaultdict +from pydantic import BaseModel, Field, ValidationError +from pydantic.types import StrictBool, StrictInt, StrictStr + + +class MatchNodeSchema(BaseModel): + prefix_len: StrictInt = Field(..., title="Prefix length") + suffix_len: StrictInt = Field(..., title="Suffix length") + prefix_tree: StrictInt = Field(..., title="Prefix tree") + suffix_tree: StrictInt = Field(..., title="Suffix tree") + + class Config: + extra = "forbid" + + +class SubstNodeSchema(BaseModel): + orig: Union[int, StrictStr] = Field(..., title="Original substring") + subst: Union[int, StrictStr] = Field(..., title="Replacement substring") + + class Config: + extra = "forbid" + + +class EditTreeSchema(BaseModel): + __root__: Union[MatchNodeSchema, SubstNodeSchema] + + +def validate_edit_tree(obj: Dict[str, Any]) -> List[str]: + """Validate edit tree. + + obj (Dict[str, Any]): JSON-serializable data to validate. + RETURNS (List[str]): A list of error messages, if available. + """ + try: + EditTreeSchema.parse_obj(obj) + return [] + except ValidationError as e: + errors = e.errors() + data = defaultdict(list) + for error in errors: + err_loc = " -> ".join([str(p) for p in error.get("loc", [])]) + data[err_loc].append(error.get("msg")) + return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()] # type: ignore[arg-type] diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py new file mode 100644 index 000000000..54a7030dc --- /dev/null +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -0,0 +1,379 @@ +from typing import cast, Any, Callable, Dict, Iterable, List, Optional +from typing import Sequence, Tuple, Union +from collections import Counter +from copy import deepcopy +from itertools import islice +import numpy as np + +import srsly +from thinc.api import Config, Model, SequenceCategoricalCrossentropy +from thinc.types import Floats2d, Ints1d, Ints2d + +from ._edit_tree_internals.edit_trees import EditTrees +from ._edit_tree_internals.schemas import validate_edit_tree +from .lemmatizer import lemmatizer_score +from .trainable_pipe import TrainablePipe +from ..errors import Errors +from ..language import Language +from ..tokens import Doc +from ..training import Example, validate_examples, validate_get_examples +from ..vocab import Vocab +from .. import util + + +default_model_config = """ +[model] +@architectures = "spacy.Tagger.v2" + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v2" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true +""" +DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"] + + +@Language.factory( + "trainable_lemmatizer", + assigns=["token.lemma"], + requires=[], + default_config={ + "model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL, + "backoff": "orth", + "min_tree_freq": 3, + "overwrite": False, + "top_k": 1, + "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, + }, + default_score_weights={"lemma_acc": 1.0}, +) +def make_edit_tree_lemmatizer( + nlp: Language, + name: str, + model: Model, + backoff: Optional[str], + min_tree_freq: int, + overwrite: bool, + top_k: int, + scorer: Optional[Callable], +): + """Construct an EditTreeLemmatizer component.""" + return EditTreeLemmatizer( + nlp.vocab, + model, + name, + backoff=backoff, + min_tree_freq=min_tree_freq, + overwrite=overwrite, + top_k=top_k, + scorer=scorer, + ) + + +class EditTreeLemmatizer(TrainablePipe): + """ + Lemmatizer that lemmatizes each word using a predicted edit tree. + """ + + def __init__( + self, + vocab: Vocab, + model: Model, + name: str = "trainable_lemmatizer", + *, + backoff: Optional[str] = "orth", + min_tree_freq: int = 3, + overwrite: bool = False, + top_k: int = 1, + scorer: Optional[Callable] = lemmatizer_score, + ): + """ + Construct an edit tree lemmatizer. + + backoff (Optional[str]): backoff to use when the predicted edit trees + are not applicable. Must be an attribute of Token or None (leave the + lemma unset). + min_tree_freq (int): prune trees that are applied less than this + frequency in the training data. + overwrite (bool): overwrite existing lemma annotations. + top_k (int): try to apply at most the k most probable edit trees. + """ + self.vocab = vocab + self.model = model + self.name = name + self.backoff = backoff + self.min_tree_freq = min_tree_freq + self.overwrite = overwrite + self.top_k = top_k + + self.trees = EditTrees(self.vocab.strings) + self.tree2label: Dict[int, int] = {} + + self.cfg: Dict[str, Any] = {"labels": []} + self.scorer = scorer + + def get_loss( + self, examples: Iterable[Example], scores: List[Floats2d] + ) -> Tuple[float, List[Floats2d]]: + validate_examples(examples, "EditTreeLemmatizer.get_loss") + loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1) + + truths = [] + for eg in examples: + eg_truths = [] + for (predicted, gold_lemma) in zip( + eg.predicted, eg.get_aligned("LEMMA", as_string=True) + ): + if gold_lemma is None: + label = -1 + else: + tree_id = self.trees.add(predicted.text, gold_lemma) + label = self.tree2label.get(tree_id, 0) + eg_truths.append(label) + + truths.append(eg_truths) + + d_scores, loss = loss_func(scores, truths) # type: ignore + if self.model.ops.xp.isnan(loss): + raise ValueError(Errors.E910.format(name=self.name)) + + return float(loss), d_scores + + def predict(self, docs: Iterable[Doc]) -> List[Ints2d]: + n_docs = len(list(docs)) + if not any(len(doc) for doc in docs): + # Handle cases where there are no tokens in any docs. + n_labels = len(self.cfg["labels"]) + guesses: List[Ints2d] = [ + self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs + ] + assert len(guesses) == n_docs + return guesses + scores = self.model.predict(docs) + assert len(scores) == n_docs + guesses = self._scores2guesses(docs, scores) + assert len(guesses) == n_docs + return guesses + + def _scores2guesses(self, docs, scores): + guesses = [] + for doc, doc_scores in zip(docs, scores): + if self.top_k == 1: + doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1) + else: + doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1] + + if not isinstance(doc_guesses, np.ndarray): + doc_guesses = doc_guesses.get() + + doc_compat_guesses = [] + for token, candidates in zip(doc, doc_guesses): + tree_id = -1 + for candidate in candidates: + candidate_tree_id = self.cfg["labels"][candidate] + + if self.trees.apply(candidate_tree_id, token.text) is not None: + tree_id = candidate_tree_id + break + doc_compat_guesses.append(tree_id) + + guesses.append(np.array(doc_compat_guesses)) + + return guesses + + def set_annotations(self, docs: Iterable[Doc], batch_tree_ids): + for i, doc in enumerate(docs): + doc_tree_ids = batch_tree_ids[i] + if hasattr(doc_tree_ids, "get"): + doc_tree_ids = doc_tree_ids.get() + for j, tree_id in enumerate(doc_tree_ids): + if self.overwrite or doc[j].lemma == 0: + # If no applicable tree could be found during prediction, + # the special identifier -1 is used. Otherwise the tree + # is guaranteed to be applicable. + if tree_id == -1: + if self.backoff is not None: + doc[j].lemma = getattr(doc[j], self.backoff) + else: + lemma = self.trees.apply(tree_id, doc[j].text) + doc[j].lemma_ = lemma + + @property + def labels(self) -> Tuple[int, ...]: + """Returns the labels currently added to the component.""" + return tuple(self.cfg["labels"]) + + @property + def hide_labels(self) -> bool: + return True + + @property + def label_data(self) -> Dict: + trees = [] + for tree_id in range(len(self.trees)): + tree = self.trees[tree_id] + if "orig" in tree: + tree["orig"] = self.vocab.strings[tree["orig"]] + if "subst" in tree: + tree["subst"] = self.vocab.strings[tree["subst"]] + trees.append(tree) + return dict(trees=trees, labels=tuple(self.cfg["labels"])) + + def initialize( + self, + get_examples: Callable[[], Iterable[Example]], + *, + nlp: Optional[Language] = None, + labels: Optional[Dict] = None, + ): + validate_get_examples(get_examples, "EditTreeLemmatizer.initialize") + + if labels is None: + self._labels_from_data(get_examples) + else: + self._add_labels(labels) + + # Sample for the model. + doc_sample = [] + label_sample = [] + for example in islice(get_examples(), 10): + doc_sample.append(example.x) + gold_labels: List[List[float]] = [] + for token in example.reference: + if token.lemma == 0: + gold_label = None + else: + gold_label = self._pair2label(token.text, token.lemma_) + + gold_labels.append( + [ + 1.0 if label == gold_label else 0.0 + for label in self.cfg["labels"] + ] + ) + + gold_labels = cast(Floats2d, gold_labels) + label_sample.append(self.model.ops.asarray(gold_labels, dtype="float32")) + + self._require_labels() + assert len(doc_sample) > 0, Errors.E923.format(name=self.name) + assert len(label_sample) > 0, Errors.E923.format(name=self.name) + + self.model.initialize(X=doc_sample, Y=label_sample) + + def from_bytes(self, bytes_data, *, exclude=tuple()): + deserializers = { + "cfg": lambda b: self.cfg.update(srsly.json_loads(b)), + "model": lambda b: self.model.from_bytes(b), + "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude), + "trees": lambda b: self.trees.from_bytes(b), + } + + util.from_bytes(bytes_data, deserializers, exclude) + + return self + + def to_bytes(self, *, exclude=tuple()): + serializers = { + "cfg": lambda: srsly.json_dumps(self.cfg), + "model": lambda: self.model.to_bytes(), + "vocab": lambda: self.vocab.to_bytes(exclude=exclude), + "trees": lambda: self.trees.to_bytes(), + } + + return util.to_bytes(serializers, exclude) + + def to_disk(self, path, exclude=tuple()): + path = util.ensure_path(path) + serializers = { + "cfg": lambda p: srsly.write_json(p, self.cfg), + "model": lambda p: self.model.to_disk(p), + "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude), + "trees": lambda p: self.trees.to_disk(p), + } + util.to_disk(path, serializers, exclude) + + def from_disk(self, path, exclude=tuple()): + def load_model(p): + try: + with open(p, "rb") as mfile: + self.model.from_bytes(mfile.read()) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserializers = { + "cfg": lambda p: self.cfg.update(srsly.read_json(p)), + "model": load_model, + "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude), + "trees": lambda p: self.trees.from_disk(p), + } + + util.from_disk(path, deserializers, exclude) + return self + + def _add_labels(self, labels: Dict): + if "labels" not in labels: + raise ValueError(Errors.E857.format(name="labels")) + if "trees" not in labels: + raise ValueError(Errors.E857.format(name="trees")) + + self.cfg["labels"] = list(labels["labels"]) + trees = [] + for tree in labels["trees"]: + errors = validate_edit_tree(tree) + if errors: + raise ValueError(Errors.E1026.format(errors="\n".join(errors))) + + tree = dict(tree) + if "orig" in tree: + tree["orig"] = self.vocab.strings[tree["orig"]] + if "orig" in tree: + tree["subst"] = self.vocab.strings[tree["subst"]] + + trees.append(tree) + + self.trees.from_json(trees) + + for label, tree in enumerate(self.labels): + self.tree2label[tree] = label + + def _labels_from_data(self, get_examples: Callable[[], Iterable[Example]]): + # Count corpus tree frequencies in ad-hoc storage to avoid cluttering + # the final pipe/string store. + vocab = Vocab() + trees = EditTrees(vocab.strings) + tree_freqs: Counter = Counter() + repr_pairs: Dict = {} + for example in get_examples(): + for token in example.reference: + if token.lemma != 0: + tree_id = trees.add(token.text, token.lemma_) + tree_freqs[tree_id] += 1 + repr_pairs[tree_id] = (token.text, token.lemma_) + + # Construct trees that make the frequency cut-off using representative + # form - token pairs. + for tree_id, freq in tree_freqs.items(): + if freq >= self.min_tree_freq: + form, lemma = repr_pairs[tree_id] + self._pair2label(form, lemma, add_label=True) + + def _pair2label(self, form, lemma, add_label=False): + """ + Look up the edit tree identifier for a form/label pair. If the edit + tree is unknown and "add_label" is set, the edit tree will be added to + the labels. + """ + tree_id = self.trees.add(form, lemma) + if tree_id not in self.tree2label: + if not add_label: + return None + + self.tree2label[tree_id] = len(self.cfg["labels"]) + self.cfg["labels"].append(tree_id) + return self.tree2label[tree_id] diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py new file mode 100644 index 000000000..cf541e301 --- /dev/null +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -0,0 +1,280 @@ +import pickle +import pytest +from hypothesis import given +import hypothesis.strategies as st +from spacy import util +from spacy.lang.en import English +from spacy.language import Language +from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees +from spacy.training import Example +from spacy.strings import StringStore +from spacy.util import make_tempdir + + +TRAIN_DATA = [ + ("She likes green eggs", {"lemmas": ["she", "like", "green", "egg"]}), + ("Eat blue ham", {"lemmas": ["eat", "blue", "ham"]}), +] + +PARTIAL_DATA = [ + # partial annotation + ("She likes green eggs", {"lemmas": ["", "like", "green", ""]}), + # misaligned partial annotation + ( + "He hates green eggs", + { + "words": ["He", "hat", "es", "green", "eggs"], + "lemmas": ["", "hat", "e", "green", ""], + }, + ), +] + + +def test_initialize_examples(): + nlp = Language() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + # you shouldn't really call this more than once, but for testing it should be fine + nlp.initialize(get_examples=lambda: train_examples) + with pytest.raises(TypeError): + nlp.initialize(get_examples=lambda: None) + with pytest.raises(TypeError): + nlp.initialize(get_examples=lambda: train_examples[0]) + with pytest.raises(TypeError): + nlp.initialize(get_examples=lambda: []) + with pytest.raises(TypeError): + nlp.initialize(get_examples=train_examples) + + +def test_initialize_from_labels(): + nlp = Language() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer.min_tree_freq = 1 + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + nlp.initialize(get_examples=lambda: train_examples) + + nlp2 = Language() + lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer") + lemmatizer2.initialize( + get_examples=lambda: train_examples, + labels=lemmatizer.label_data, + ) + assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3} + + +def test_no_data(): + # Test that the lemmatizer provides a nice error when there's no tagging data / labels + TEXTCAT_DATA = [ + ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}), + ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}), + ] + nlp = English() + nlp.add_pipe("trainable_lemmatizer") + nlp.add_pipe("textcat") + + train_examples = [] + for t in TEXTCAT_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + with pytest.raises(ValueError): + nlp.initialize(get_examples=lambda: train_examples) + + +def test_incomplete_data(): + # Test that the lemmatizer works with incomplete information + nlp = English() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer.min_tree_freq = 1 + train_examples = [] + for t in PARTIAL_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(50): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + assert losses["trainable_lemmatizer"] < 0.00001 + + # test the trained model + test_text = "She likes blue eggs" + doc = nlp(test_text) + assert doc[1].lemma_ == "like" + assert doc[2].lemma_ == "blue" + + +def test_overfitting_IO(): + nlp = English() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer.min_tree_freq = 1 + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + optimizer = nlp.initialize(get_examples=lambda: train_examples) + + for i in range(50): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + assert losses["trainable_lemmatizer"] < 0.00001 + + test_text = "She likes blue eggs" + doc = nlp(test_text) + assert doc[0].lemma_ == "she" + assert doc[1].lemma_ == "like" + assert doc[2].lemma_ == "blue" + assert doc[3].lemma_ == "egg" + + # Check model after a {to,from}_disk roundtrip + with util.make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + assert doc2[0].lemma_ == "she" + assert doc2[1].lemma_ == "like" + assert doc2[2].lemma_ == "blue" + assert doc2[3].lemma_ == "egg" + + # Check model after a {to,from}_bytes roundtrip + nlp_bytes = nlp.to_bytes() + nlp3 = English() + nlp3.add_pipe("trainable_lemmatizer") + nlp3.from_bytes(nlp_bytes) + doc3 = nlp3(test_text) + assert doc3[0].lemma_ == "she" + assert doc3[1].lemma_ == "like" + assert doc3[2].lemma_ == "blue" + assert doc3[3].lemma_ == "egg" + + # Check model after a pickle roundtrip. + nlp_bytes = pickle.dumps(nlp) + nlp4 = pickle.loads(nlp_bytes) + doc4 = nlp4(test_text) + assert doc4[0].lemma_ == "she" + assert doc4[1].lemma_ == "like" + assert doc4[2].lemma_ == "blue" + assert doc4[3].lemma_ == "egg" + + +def test_lemmatizer_requires_labels(): + nlp = English() + nlp.add_pipe("trainable_lemmatizer") + with pytest.raises(ValueError): + nlp.initialize() + + +def test_lemmatizer_label_data(): + nlp = English() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer.min_tree_freq = 1 + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + nlp.initialize(get_examples=lambda: train_examples) + + nlp2 = English() + lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer") + lemmatizer2.initialize( + get_examples=lambda: train_examples, labels=lemmatizer.label_data + ) + + # Verify that the labels and trees are the same. + assert lemmatizer.labels == lemmatizer2.labels + assert lemmatizer.trees.to_bytes() == lemmatizer2.trees.to_bytes() + + +def test_dutch(): + strings = StringStore() + trees = EditTrees(strings) + tree = trees.add("deelt", "delen") + assert trees.tree_to_str(tree) == "(m 0 3 () (m 0 2 (s '' 'l') (s 'lt' 'n')))" + + tree = trees.add("gedeeld", "delen") + assert ( + trees.tree_to_str(tree) == "(m 2 3 (s 'ge' '') (m 0 2 (s '' 'l') (s 'ld' 'n')))" + ) + + +def test_from_to_bytes(): + strings = StringStore() + trees = EditTrees(strings) + trees.add("deelt", "delen") + trees.add("gedeeld", "delen") + + b = trees.to_bytes() + + trees2 = EditTrees(strings) + trees2.from_bytes(b) + + # Verify that the nodes did not change. + assert len(trees) == len(trees2) + for i in range(len(trees)): + assert trees.tree_to_str(i) == trees2.tree_to_str(i) + + # Reinserting the same trees should not add new nodes. + trees2.add("deelt", "delen") + trees2.add("gedeeld", "delen") + assert len(trees) == len(trees2) + + +def test_from_to_disk(): + strings = StringStore() + trees = EditTrees(strings) + trees.add("deelt", "delen") + trees.add("gedeeld", "delen") + + trees2 = EditTrees(strings) + with make_tempdir() as temp_dir: + trees_file = temp_dir / "edit_trees.bin" + trees.to_disk(trees_file) + trees2 = trees2.from_disk(trees_file) + + # Verify that the nodes did not change. + assert len(trees) == len(trees2) + for i in range(len(trees)): + assert trees.tree_to_str(i) == trees2.tree_to_str(i) + + # Reinserting the same trees should not add new nodes. + trees2.add("deelt", "delen") + trees2.add("gedeeld", "delen") + assert len(trees) == len(trees2) + + +@given(st.text(), st.text()) +def test_roundtrip(form, lemma): + strings = StringStore() + trees = EditTrees(strings) + tree = trees.add(form, lemma) + assert trees.apply(tree, form) == lemma + + +@given(st.text(alphabet="ab"), st.text(alphabet="ab")) +def test_roundtrip_small_alphabet(form, lemma): + # Test with small alphabets to have more overlap. + strings = StringStore() + trees = EditTrees(strings) + tree = trees.add(form, lemma) + assert trees.apply(tree, form) == lemma + + +def test_unapplicable_trees(): + strings = StringStore() + trees = EditTrees(strings) + tree3 = trees.add("deelt", "delen") + + # Replacement fails. + assert trees.apply(tree3, "deeld") == None + + # Suffix + prefix are too large. + assert trees.apply(tree3, "de") == None + + +def test_empty_strings(): + strings = StringStore() + trees = EditTrees(strings) + no_change = trees.add("xyz", "xyz") + empty = trees.add("", "") + assert no_change == empty diff --git a/website/docs/api/edittreelemmatizer.md b/website/docs/api/edittreelemmatizer.md new file mode 100644 index 000000000..99a705f5e --- /dev/null +++ b/website/docs/api/edittreelemmatizer.md @@ -0,0 +1,409 @@ +--- +title: EditTreeLemmatizer +tag: class +source: spacy/pipeline/edit_tree_lemmatizer.py +new: 3.3 +teaser: 'Pipeline component for lemmatization' +api_base_class: /api/pipe +api_string_name: trainable_lemmatizer +api_trainable: true +--- + +A trainable component for assigning base forms to tokens. This lemmatizer uses +**edit trees** to transform tokens into base forms. The lemmatization model +predicts which edit tree is applicable to a token. The edit tree data structure +and construction method used by this lemmatizer were proposed in +[Joint Lemmatization and Morphological Tagging with Lemming](https://aclanthology.org/D15-1272.pdf) +(Thomas Müller et al., 2015). + +For a lookup and rule-based lemmatizer, see [`Lemmatizer`](/api/lemmatizer). + +## Assigned Attributes {#assigned-attributes} + +Predictions are assigned to `Token.lemma`. + +| Location | Value | +| -------------- | ------------------------- | +| `Token.lemma` | The lemma (hash). ~~int~~ | +| `Token.lemma_` | The lemma. ~~str~~ | + +## Config and implementation {#config} + +The default config is defined by the pipeline component factory and describes +how the component should be configured. You can override its settings via the +`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your +[`config.cfg` for training](/usage/training#config). See the +[model architectures](/api/architectures) documentation for details on the +architectures and their arguments and hyperparameters. + +> #### Example +> +> ```python +> from spacy.pipeline.edit_tree_lemmatizer import DEFAULT_EDIT_TREE_LEMMATIZER_MODEL +> config = {"model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL} +> nlp.add_pipe("trainable_lemmatizer", config=config, name="lemmatizer") +> ``` + +| Setting | Description | +| --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `model` | A model instance that predicts the edit tree probabilities. The output vectors should match the number of edit trees in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ | +| `backoff` | ~~Token~~ attribute to use when no applicable edit tree is found. Defaults to `orth`. ~~str~~ | +| `min_tree_freq` | Minimum frequency of an edit tree in the training set to be used. Defaults to `3`. ~~int~~ | +| `overwrite` | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ | +| `top_k` | The number of most probable edit trees to try before resorting to `backoff`. Defaults to `1`. ~~int~~ | +| `scorer` | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"lemma"`. ~~Optional[Callable]~~ | + +```python +%%GITHUB_SPACY/spacy/pipeline/edit_tree_lemmatizer.py +``` + +## EditTreeLemmatizer.\_\_init\_\_ {#init tag="method"} + +> #### Example +> +> ```python +> # Construction via add_pipe with default model +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> +> # Construction via create_pipe with custom model +> config = {"model": {"@architectures": "my_tagger"}} +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", config=config, name="lemmatizer") +> +> # Construction from class +> from spacy.pipeline import EditTreeLemmatizer +> lemmatizer = EditTreeLemmatizer(nlp.vocab, model) +> ``` + +Create a new pipeline instance. In your application, you would normally use a +shortcut for this and instantiate the component using its string name and +[`nlp.add_pipe`](/api/language#add_pipe). + +| Name | Description | +| --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `vocab` | The shared vocabulary. ~~Vocab~~ | +| `model` | A model instance that predicts the edit tree probabilities. The output vectors should match the number of edit trees in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). ~~Model[List[Doc], List[Floats2d]]~~ | +| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | +| _keyword-only_ | | +| `backoff` | ~~Token~~ attribute to use when no applicable edit tree is found. Defaults to `orth`. ~~str~~ | +| `min_tree_freq` | Minimum frequency of an edit tree in the training set to be used. Defaults to `3`. ~~int~~ | +| `overwrite` | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ | +| `top_k` | The number of most probable edit trees to try before resorting to `backoff`. Defaults to `1`. ~~int~~ | +| `scorer` | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"lemma"`. ~~Optional[Callable]~~ | + +## EditTreeLemmatizer.\_\_call\_\_ {#call tag="method"} + +Apply the pipe to one document. The document is modified in place, and returned. +This usually happens under the hood when the `nlp` object is called on a text +and all pipeline components are applied to the `Doc` in order. Both +[`__call__`](/api/edittreelemmatizer#call) and +[`pipe`](/api/edittreelemmatizer#pipe) delegate to the +[`predict`](/api/edittreelemmatizer#predict) and +[`set_annotations`](/api/edittreelemmatizer#set_annotations) methods. + +> #### Example +> +> ```python +> doc = nlp("This is a sentence.") +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> # This usually happens under the hood +> processed = lemmatizer(doc) +> ``` + +| Name | Description | +| ----------- | -------------------------------- | +| `doc` | The document to process. ~~Doc~~ | +| **RETURNS** | The processed document. ~~Doc~~ | + +## EditTreeLemmatizer.pipe {#pipe tag="method"} + +Apply the pipe to a stream of documents. This usually happens under the hood +when the `nlp` object is called on a text and all pipeline components are +applied to the `Doc` in order. Both [`__call__`](/api/edittreelemmatizer#call) +and [`pipe`](/api/edittreelemmatizer#pipe) delegate to the +[`predict`](/api/edittreelemmatizer#predict) and +[`set_annotations`](/api/edittreelemmatizer#set_annotations) methods. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> for doc in lemmatizer.pipe(docs, batch_size=50): +> pass +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------- | +| `stream` | A stream of documents. ~~Iterable[Doc]~~ | +| _keyword-only_ | | +| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ | +| **YIELDS** | The processed documents in order. ~~Doc~~ | + +## EditTreeLemmatizer.initialize {#initialize tag="method" new="3"} + +Initialize the component for training. `get_examples` should be a function that +returns an iterable of [`Example`](/api/example) objects. The data examples are +used to **initialize the model** of the component and can either be the full +training data or a representative sample. Initialization includes validating the +network, +[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and +setting up the label scheme based on the data. This method is typically called +by [`Language.initialize`](/api/language#initialize) and lets you customize +arguments it receives via the +[`[initialize.components]`](/api/data-formats#config-initialize) block in the +config. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> lemmatizer.initialize(lambda: [], nlp=nlp) +> ``` +> +> ```ini +> ### config.cfg +> [initialize.components.lemmatizer] +> +> [initialize.components.lemmatizer.labels] +> @readers = "spacy.read_labels.v1" +> path = "corpus/labels/lemmatizer.json +> ``` + +| Name | Description | +| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ | +| _keyword-only_ | | +| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ | +| `labels` | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ | + +## EditTreeLemmatizer.predict {#predict tag="method"} + +Apply the component's model to a batch of [`Doc`](/api/doc) objects, without +modifying them. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> tree_ids = lemmatizer.predict([doc1, doc2]) +> ``` + +| Name | Description | +| ----------- | ------------------------------------------- | +| `docs` | The documents to predict. ~~Iterable[Doc]~~ | +| **RETURNS** | The model's prediction for each document. | + +## EditTreeLemmatizer.set_annotations {#set_annotations tag="method"} + +Modify a batch of [`Doc`](/api/doc) objects, using pre-computed tree +identifiers. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> tree_ids = lemmatizer.predict([doc1, doc2]) +> lemmatizer.set_annotations([doc1, doc2], tree_ids) +> ``` + +| Name | Description | +| ---------- | ------------------------------------------------------------------------------------- | +| `docs` | The documents to modify. ~~Iterable[Doc]~~ | +| `tree_ids` | The identifiers of the edit trees to apply, produced by `EditTreeLemmatizer.predict`. | + +## EditTreeLemmatizer.update {#update tag="method"} + +Learn from a batch of [`Example`](/api/example) objects containing the +predictions and gold-standard annotations, and update the component's model. +Delegates to [`predict`](/api/edittreelemmatizer#predict) and +[`get_loss`](/api/edittreelemmatizer#get_loss). + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> optimizer = nlp.initialize() +> losses = lemmatizer.update(examples, sgd=optimizer) +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------------------------------------ | +| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| `drop` | The dropout rate. ~~float~~ | +| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | +| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | +| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | + +## EditTreeLemmatizer.get_loss {#get_loss tag="method"} + +Find the loss and gradient of loss for the batch of documents and their +predicted scores. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> scores = lemmatizer.model.begin_update([eg.predicted for eg in examples]) +> loss, d_loss = lemmatizer.get_loss(examples, scores) +> ``` + +| Name | Description | +| ----------- | --------------------------------------------------------------------------- | +| `examples` | The batch of examples. ~~Iterable[Example]~~ | +| `scores` | Scores representing the model's predictions. | +| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ | + +## EditTreeLemmatizer.create_optimizer {#create_optimizer tag="method"} + +Create an optimizer for the pipeline component. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> optimizer = lemmatizer.create_optimizer() +> ``` + +| Name | Description | +| ----------- | ---------------------------- | +| **RETURNS** | The optimizer. ~~Optimizer~~ | + +## EditTreeLemmatizer.use_params {#use_params tag="method, contextmanager"} + +Modify the pipe's model, to use the given parameter values. At the end of the +context, the original parameters are restored. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> with lemmatizer.use_params(optimizer.averages): +> lemmatizer.to_disk("/best_model") +> ``` + +| Name | Description | +| -------- | -------------------------------------------------- | +| `params` | The parameter values to use in the model. ~~dict~~ | + +## EditTreeLemmatizer.to_disk {#to_disk tag="method"} + +Serialize the pipe to disk. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> lemmatizer.to_disk("/path/to/lemmatizer") +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | + +## EditTreeLemmatizer.from_disk {#from_disk tag="method"} + +Load the pipe from disk. Modifies the object in place and returns it. + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> lemmatizer.from_disk("/path/to/lemmatizer") +> ``` + +| Name | Description | +| -------------- | ----------------------------------------------------------------------------------------------- | +| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | +| **RETURNS** | The modified `EditTreeLemmatizer` object. ~~EditTreeLemmatizer~~ | + +## EditTreeLemmatizer.to_bytes {#to_bytes tag="method"} + +> #### Example +> +> ```python +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> lemmatizer_bytes = lemmatizer.to_bytes() +> ``` + +Serialize the pipe to a bytestring. + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------- | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | +| **RETURNS** | The serialized form of the `EditTreeLemmatizer` object. ~~bytes~~ | + +## EditTreeLemmatizer.from_bytes {#from_bytes tag="method"} + +Load the pipe from a bytestring. Modifies the object in place and returns it. + +> #### Example +> +> ```python +> lemmatizer_bytes = lemmatizer.to_bytes() +> lemmatizer = nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +> lemmatizer.from_bytes(lemmatizer_bytes) +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------- | +| `bytes_data` | The data to load from. ~~bytes~~ | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | +| **RETURNS** | The `EditTreeLemmatizer` object. ~~EditTreeLemmatizer~~ | + +## EditTreeLemmatizer.labels {#labels tag="property"} + +The labels currently added to the component. + + + +The `EditTreeLemmatizer` labels are not useful by themselves, since they are +identifiers of edit trees. + + + +| Name | Description | +| ----------- | ------------------------------------------------------ | +| **RETURNS** | The labels added to the component. ~~Tuple[str, ...]~~ | + +## EditTreeLemmatizer.label_data {#label_data tag="property" new="3"} + +The labels currently added to the component and their internal meta information. +This is the data generated by [`init labels`](/api/cli#init-labels) and used by +[`EditTreeLemmatizer.initialize`](/api/edittreelemmatizer#initialize) to +initialize the model with a pre-defined label set. + +> #### Example +> +> ```python +> labels = lemmatizer.label_data +> lemmatizer.initialize(lambda: [], nlp=nlp, labels=labels) +> ``` + +| Name | Description | +| ----------- | ---------------------------------------------------------- | +| **RETURNS** | The label data added to the component. ~~Tuple[str, ...]~~ | + +## Serialization fields {#serialization-fields} + +During serialization, spaCy will export several data fields used to restore +different aspects of the object. If needed, you can exclude them from +serialization by passing in the string names via the `exclude` argument. + +> #### Example +> +> ```python +> data = lemmatizer.to_disk("/path", exclude=["vocab"]) +> ``` + +| Name | Description | +| ------- | -------------------------------------------------------------- | +| `vocab` | The shared [`Vocab`](/api/vocab). | +| `cfg` | The config file. You usually don't want to exclude this. | +| `model` | The binary model data. You usually don't want to exclude this. | +| `trees` | The edit trees. You usually don't want to exclude this. | diff --git a/website/docs/api/lemmatizer.md b/website/docs/api/lemmatizer.md index 2fa040917..75387305a 100644 --- a/website/docs/api/lemmatizer.md +++ b/website/docs/api/lemmatizer.md @@ -9,14 +9,15 @@ api_trainable: false --- Component for assigning base forms to tokens using rules based on part-of-speech -tags, or lookup tables. Functionality to train the component is coming soon. -Different [`Language`](/api/language) subclasses can implement their own -lemmatizer components via +tags, or lookup tables. Different [`Language`](/api/language) subclasses can +implement their own lemmatizer components via [language-specific factories](/usage/processing-pipelines#factories-language). The default data used is provided by the [`spacy-lookups-data`](https://github.com/explosion/spacy-lookups-data) extension package. +For a trainable lemmatizer, see [`EditTreeLemmatizer`](/api/edittreelemmatizer). + As of v3.0, the `Lemmatizer` is a **standalone pipeline component** that can be diff --git a/website/docs/usage/101/_architecture.md b/website/docs/usage/101/_architecture.md index 8fb452895..22e2b961e 100644 --- a/website/docs/usage/101/_architecture.md +++ b/website/docs/usage/101/_architecture.md @@ -45,10 +45,11 @@ components for different language processing tasks and also allows adding | ----------------------------------------------- | ------------------------------------------------------------------------------------------- | | [`AttributeRuler`](/api/attributeruler) | Set token attributes using matcher rules. | | [`DependencyParser`](/api/dependencyparser) | Predict syntactic dependencies. | +| [`EditTreeLemmatizer`](/api/edittreelemmatizer) | Predict base forms of words. | | [`EntityLinker`](/api/entitylinker) | Disambiguate named entities to nodes in a knowledge base. | | [`EntityRecognizer`](/api/entityrecognizer) | Predict named entities, e.g. persons or products. | | [`EntityRuler`](/api/entityruler) | Add entity spans to the `Doc` using token-based rules or exact phrase matches. | -| [`Lemmatizer`](/api/lemmatizer) | Determine the base forms of words. | +| [`Lemmatizer`](/api/lemmatizer) | Determine the base forms of words using rules and lookups. | | [`Morphologizer`](/api/morphologizer) | Predict morphological features and coarse-grained part-of-speech tags. | | [`SentenceRecognizer`](/api/sentencerecognizer) | Predict sentence boundaries. | | [`Sentencizer`](/api/sentencizer) | Implement rule-based sentence boundary detection that doesn't require the dependency parse. | diff --git a/website/docs/usage/linguistic-features.md b/website/docs/usage/linguistic-features.md index c3f25565a..b3b896a54 100644 --- a/website/docs/usage/linguistic-features.md +++ b/website/docs/usage/linguistic-features.md @@ -120,10 +120,13 @@ print(doc[2].pos_) # 'PRON' ## Lemmatization {#lemmatization model="lemmatizer" new="3"} -The [`Lemmatizer`](/api/lemmatizer) is a pipeline component that provides lookup -and rule-based lemmatization methods in a configurable component. An individual -language can extend the `Lemmatizer` as part of its -[language data](#language-data). +spaCy provides two pipeline components for lemmatization: + +1. The [`Lemmatizer`](/api/lemmatizer) component provides lookup and rule-based + lemmatization methods in a configurable component. An individual language can + extend the `Lemmatizer` as part of its [language data](#language-data). +2. The [`EditTreeLemmatizer`](/api/edittreelemmatizer) + 3.3 component provides a trainable lemmatizer. ```python ### {executable="true"} @@ -197,6 +200,20 @@ information, without consulting the context of the token. The rule-based lemmatizer also accepts list-based exception files. For English, these are acquired from [WordNet](https://wordnet.princeton.edu/). +### Trainable lemmatizer + +The [`EditTreeLemmatizer`](/api/edittreelemmatizer) can learn form-to-lemma +transformations from a training corpus that includes lemma annotations. This +removes the need to write language-specific rules and can (in many cases) +provide higher accuracies than lookup and rule-based lemmatizers. + +```python +import spacy + +nlp = spacy.blank("de") +nlp.add_pipe("trainable_lemmatizer", name="lemmatizer") +``` + ## Dependency Parsing {#dependency-parse model="parser"} spaCy features a fast and accurate syntactic dependency parser, and has a rich @@ -1189,7 +1206,7 @@ class WhitespaceTokenizer: spaces = spaces[0:-1] else: spaces[-1] = False - + return Doc(self.vocab, words=words, spaces=spaces) nlp = spacy.blank("en") @@ -1269,8 +1286,8 @@ hyperparameters, pipeline and tokenizer used for constructing and training the pipeline. The `[nlp.tokenizer]` block refers to a **registered function** that takes the `nlp` object and returns a tokenizer. Here, we're registering a function called `whitespace_tokenizer` in the -[`@tokenizers` registry](/api/top-level#registry). To make sure spaCy knows how to -construct your tokenizer during training, you can pass in your Python file by +[`@tokenizers` registry](/api/top-level#registry). To make sure spaCy knows how +to construct your tokenizer during training, you can pass in your Python file by setting `--code functions.py` when you run [`spacy train`](/api/cli#train). > #### config.cfg diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 9e6ee54df..4f75b5193 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -303,22 +303,23 @@ available pipeline components and component functions. > ruler = nlp.add_pipe("entity_ruler") > ``` -| String name | Component | Description | -| -------------------- | ---------------------------------------------------- | ----------------------------------------------------------------------------------------- | -| `tagger` | [`Tagger`](/api/tagger) | Assign part-of-speech-tags. | -| `parser` | [`DependencyParser`](/api/dependencyparser) | Assign dependency labels. | -| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. | -| `entity_linker` | [`EntityLinker`](/api/entitylinker) | Assign knowledge base IDs to named entities. Should be added after the entity recognizer. | -| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules and dictionaries. | -| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories: exactly one category is predicted per document. | -| `textcat_multilabel` | [`MultiLabel_TextCategorizer`](/api/textcategorizer) | Assign text categories in a multi-label setting: zero, one or more labels per document. | -| `lemmatizer` | [`Lemmatizer`](/api/lemmatizer) | Assign base forms to words. | -| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. | -| `attribute_ruler` | [`AttributeRuler`](/api/attributeruler) | Assign token attribute mappings and rule-based exceptions. | -| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. | -| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. | -| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | Assign token-to-vector embeddings. | -| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. | +| String name | Component | Description | +| ---------------------- | ---------------------------------------------------- | ----------------------------------------------------------------------------------------- | +| `tagger` | [`Tagger`](/api/tagger) | Assign part-of-speech-tags. | +| `parser` | [`DependencyParser`](/api/dependencyparser) | Assign dependency labels. | +| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. | +| `entity_linker` | [`EntityLinker`](/api/entitylinker) | Assign knowledge base IDs to named entities. Should be added after the entity recognizer. | +| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules and dictionaries. | +| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories: exactly one category is predicted per document. | +| `textcat_multilabel` | [`MultiLabel_TextCategorizer`](/api/textcategorizer) | Assign text categories in a multi-label setting: zero, one or more labels per document. | +| `lemmatizer` | [`Lemmatizer`](/api/lemmatizer) | Assign base forms to words using rules and lookups. | +| `trainable_lemmatizer` | [`EditTreeLemmatizer`](/api/edittreelemmatizer) | Assign base forms to words. | +| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. | +| `attribute_ruler` | [`AttributeRuler`](/api/attributeruler) | Assign token attribute mappings and rule-based exceptions. | +| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. | +| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. | +| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | Assign token-to-vector embeddings. | +| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. | ### Disabling, excluding and modifying components {#disabling} diff --git a/website/meta/sidebars.json b/website/meta/sidebars.json index c49b49c73..2229c91f3 100644 --- a/website/meta/sidebars.json +++ b/website/meta/sidebars.json @@ -93,6 +93,7 @@ "items": [ { "text": "AttributeRuler", "url": "/api/attributeruler" }, { "text": "DependencyParser", "url": "/api/dependencyparser" }, + { "text": "EditTreeLemmatizer", "url": "/api/edittreelemmatizer" }, { "text": "EntityLinker", "url": "/api/entitylinker" }, { "text": "EntityRecognizer", "url": "/api/entityrecognizer" }, { "text": "EntityRuler", "url": "/api/entityruler" },