diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 95f9c66df..120e9b02c 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union from thinc.types import Floats2d from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed @@ -10,7 +10,7 @@ from ...ml import _character_embed from ..staticvectors import StaticVectors from ..featureextractor import FeatureExtractor from ...pipeline.tok2vec import Tok2VecListener -from ...attrs import ORTH, NORM, PREFIX, SUFFIX, SHAPE +from ...attrs import ORTH, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr @registry.architectures.register("spacy.Tok2VecListener.v1") @@ -98,7 +98,7 @@ def MultiHashEmbed( attributes using hash embedding, concatenates the results, and passes it through a feed-forward subnetwork to build a mixed representations. - The features used are the NORM, PREFIX, SUFFIX and SHAPE, which can have + The features used are the LOWER, PREFIX, SUFFIX and SHAPE, which can have varying definitions depending on the Vocab of the Doc object passed in. Vectors from pretrained static vectors can also be incorporated into the concatenated representation. @@ -115,7 +115,7 @@ def MultiHashEmbed( also_use_static_vectors (bool): Whether to also use static word vectors. Requires a vectors table to be loaded in the Doc objects' vocab. 
""" - cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH] + cols = [LOWER, PREFIX, SUFFIX, SHAPE, ORTH] seed = 7 def make_hash_embed(feature): @@ -123,7 +123,7 @@ def MultiHashEmbed( seed += 1 return HashEmbed( width, - rows if feature == NORM else rows // 2, + rows if feature == LOWER else rows // 2, column=cols.index(feature), seed=seed, dropout=0.0, @@ -131,13 +131,13 @@ def MultiHashEmbed( if also_embed_subwords: embeddings = [ - make_hash_embed(NORM), + make_hash_embed(LOWER), make_hash_embed(PREFIX), make_hash_embed(SUFFIX), make_hash_embed(SHAPE), ] else: - embeddings = [make_hash_embed(NORM)] + embeddings = [make_hash_embed(LOWER)] concat_size = width * (len(embeddings) + also_use_static_vectors) if also_use_static_vectors: model = chain( @@ -165,7 +165,8 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed( - width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool + width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool, + feature: Union[int, str]="LOWER" ): """Construct an embedded representation based on character embeddings, using a feed-forward network. A fixed number of UTF-8 byte characters are used for @@ -179,12 +180,13 @@ def CharacterEmbed( of being in an arbitrary position depending on the word length. The characters are embedded in a embedding table with a given number of rows, - and the vectors concatenated. A hash-embedded vector of the NORM of the word is + and the vectors concatenated. A hash-embedded vector of the LOWER of the word is also concatenated on, and the result is then passed through a feed-forward network to construct a single vector to represent the information. - width (int): The width of the output vector and the NORM hash embedding. - rows (int): The number of rows in the NORM hash embedding table. + feature (int or str): An attribute to embed, to concatenate with the characters. + width (int): The width of the output vector and the feature embedding. 
+ rows (int): The number of rows in the feature hash embedding table. nM (int): The dimensionality of the character embeddings. Recommended values are between 16 and 64. nC (int): The number of UTF-8 bytes to embed per word. Recommended values @@ -193,12 +195,15 @@ def CharacterEmbed( also_use_static_vectors (bool): Whether to also use static word vectors. Requires a vectors table to be loaded in the Doc objects' vocab. """ + feature = intify_attr(feature) + if feature is None: + raise ValueError("Invalid feature: Must be a token attribute.") if also_use_static_vectors: model = chain( concatenate( chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain( - FeatureExtractor([NORM]), + FeatureExtractor([feature]), list2ragged(), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), ), @@ -214,7 +219,7 @@ def CharacterEmbed( concatenate( chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain( - FeatureExtractor([NORM]), + FeatureExtractor([feature]), list2ragged(), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),