Allow CharacterEmbed to specify feature

This commit is contained in:
Matthew Honnibal 2020-10-01 22:17:26 +02:00
parent 0a8a124a6e
commit 684a77870b
1 changed file with 10 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional, List from typing import Optional, List, Union
from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list from thinc.api import Model, noop, list2ragged, ragged2list
from thinc.api import FeatureExtractor, HashEmbed from thinc.api import FeatureExtractor, HashEmbed
@@ -165,7 +165,8 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1") @registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed( def CharacterEmbed(
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
feature: Union[int, str]="NORM"
): ):
"""Construct an embedded representation based on character embeddings, using """Construct an embedded representation based on character embeddings, using
a feed-forward network. A fixed number of UTF-8 byte characters are used for a feed-forward network. A fixed number of UTF-8 byte characters are used for
@@ -183,7 +184,8 @@ def CharacterEmbed(
also concatenated on, and the result is then passed through a feed-forward also concatenated on, and the result is then passed through a feed-forward
network to construct a single vector to represent the information. network to construct a single vector to represent the information.
width (int): The width of the output vector and the NORM hash embedding. feature (int or str): An attribute to embed, to concatenate with the characters.
width (int): The width of the output vector and the feature embedding.
rows (int): The number of rows in the NORM hash embedding table. rows (int): The number of rows in the NORM hash embedding table.
nM (int): The dimensionality of the character embeddings. Recommended values nM (int): The dimensionality of the character embeddings. Recommended values
are between 16 and 64. are between 16 and 64.
@@ -193,12 +195,15 @@ def CharacterEmbed(
also_use_static_vectors (bool): Whether to also use static word vectors. also_use_static_vectors (bool): Whether to also use static word vectors.
Requires a vectors table to be loaded in the Doc objects' vocab. Requires a vectors table to be loaded in the Doc objects' vocab.
""" """
feature = intify_attr(feature)
if feature is None:
raise ValueError("Invalid feature: Must be a token attribute.")
if also_use_static_vectors: if also_use_static_vectors:
model = chain( model = chain(
concatenate( concatenate(
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
chain( chain(
FeatureExtractor([NORM]), FeatureExtractor([feature]),
list2ragged(), list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
), ),
@@ -214,7 +219,7 @@ def CharacterEmbed(
concatenate( concatenate(
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
chain( chain(
FeatureExtractor([NORM]), FeatureExtractor([feature]),
list2ragged(), list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
), ),