Allow CharacterEmbed to specify feature

This commit is contained in:
Matthew Honnibal 2020-10-01 22:17:26 +02:00
parent 0a8a124a6e
commit 684a77870b
1 changed files with 10 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Union
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list
from thinc.api import FeatureExtractor, HashEmbed
@ -165,7 +165,8 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
feature: Union[int, str]="NORM"
):
"""Construct an embedded representation based on character embeddings, using
a feed-forward network. A fixed number of UTF-8 byte characters are used for
@ -183,7 +184,8 @@ def CharacterEmbed(
also concatenated on, and the result is then passed through a feed-forward
network to construct a single vector to represent the information.
width (int): The width of the output vector and the NORM hash embedding.
feature (int or str): An attribute to embed, to concatenate with the characters.
width (int): The width of the output vector and the feature embedding.
rows (int): The number of rows in the NORM hash embedding table.
nM (int): The dimensionality of the character embeddings. Recommended values
are between 16 and 64.
@ -193,12 +195,15 @@ def CharacterEmbed(
also_use_static_vectors (bool): Whether to also use static word vectors.
Requires a vectors table to be loaded in the Doc objects' vocab.
"""
feature = intify_attr(feature)
if feature is None:
raise ValueError("Invalid feature: Must be a token attribute.")
if also_use_static_vectors:
model = chain(
concatenate(
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
chain(
FeatureExtractor([NORM]),
FeatureExtractor([feature]),
list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
),
@ -214,7 +219,7 @@ def CharacterEmbed(
concatenate(
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
chain(
FeatureExtractor([NORM]),
FeatureExtractor([feature]),
list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
),