diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index fec478e21..888dc9caa 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import Model, noop, list2ragged, ragged2list from thinc.api import FeatureExtractor, HashEmbed @@ -165,7 +165,8 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed( - width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool + width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool, + feature: Union[int, str]="NORM" ): """Construct an embedded representation based on character embeddings, using a feed-forward network. A fixed number of UTF-8 byte characters are used for @@ -183,7 +184,8 @@ def CharacterEmbed( also concatenated on, and the result is then passed through a feed-forward network to construct a single vector to represent the information. - width (int): The width of the output vector and the NORM hash embedding. + feature (int or str): An attribute to embed, to concatenate with the characters. + width (int): The width of the output vector and the feature embedding. rows (int): The number of rows in the NORM hash embedding table. nM (int): The dimensionality of the character embeddings. Recommended values are between 16 and 64. @@ -193,12 +195,15 @@ def CharacterEmbed( also_use_static_vectors (bool): Whether to also use static word vectors. Requires a vectors table to be loaded in the Doc objects' vocab. """ + feature = intify_attr(feature) + if feature is None: + raise ValueError("Invalid feature: Must be a token attribute.") if also_use_static_vectors: model = chain( concatenate( chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain( - FeatureExtractor([NORM]), + FeatureExtractor([feature]), list2ragged(), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), ), @@ -214,7 +219,7 @@ def CharacterEmbed( concatenate( chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain( - FeatureExtractor([NORM]), + FeatureExtractor([feature]), list2ragged(), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), ),