mirror of https://github.com/explosion/spaCy.git
Avoid relying on NORM in default v3 models (#6176)
* Allow CharacterEmbed to specify feature * Default to LOWER in character embed * Update tok2vec * Use LOWER, not NORM
This commit is contained in:
parent
50162b8726
commit
300e5a9928
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Union
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
||||||
|
@ -10,7 +10,7 @@ from ...ml import _character_embed
|
||||||
from ..staticvectors import StaticVectors
|
from ..staticvectors import StaticVectors
|
||||||
from ..featureextractor import FeatureExtractor
|
from ..featureextractor import FeatureExtractor
|
||||||
from ...pipeline.tok2vec import Tok2VecListener
|
from ...pipeline.tok2vec import Tok2VecListener
|
||||||
from ...attrs import ORTH, NORM, PREFIX, SUFFIX, SHAPE
|
from ...attrs import ORTH, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
||||||
|
@ -98,7 +98,7 @@ def MultiHashEmbed(
|
||||||
attributes using hash embedding, concatenates the results, and passes it
|
attributes using hash embedding, concatenates the results, and passes it
|
||||||
through a feed-forward subnetwork to build a mixed representations.
|
through a feed-forward subnetwork to build a mixed representations.
|
||||||
|
|
||||||
The features used are the NORM, PREFIX, SUFFIX and SHAPE, which can have
|
The features used are the LOWER, PREFIX, SUFFIX and SHAPE, which can have
|
||||||
varying definitions depending on the Vocab of the Doc object passed in.
|
varying definitions depending on the Vocab of the Doc object passed in.
|
||||||
Vectors from pretrained static vectors can also be incorporated into the
|
Vectors from pretrained static vectors can also be incorporated into the
|
||||||
concatenated representation.
|
concatenated representation.
|
||||||
|
@ -115,7 +115,7 @@ def MultiHashEmbed(
|
||||||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||||
"""
|
"""
|
||||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
cols = [LOWER, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||||
seed = 7
|
seed = 7
|
||||||
|
|
||||||
def make_hash_embed(feature):
|
def make_hash_embed(feature):
|
||||||
|
@ -123,7 +123,7 @@ def MultiHashEmbed(
|
||||||
seed += 1
|
seed += 1
|
||||||
return HashEmbed(
|
return HashEmbed(
|
||||||
width,
|
width,
|
||||||
rows if feature == NORM else rows // 2,
|
rows if feature == LOWER else rows // 2,
|
||||||
column=cols.index(feature),
|
column=cols.index(feature),
|
||||||
seed=seed,
|
seed=seed,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
|
@ -131,13 +131,13 @@ def MultiHashEmbed(
|
||||||
|
|
||||||
if also_embed_subwords:
|
if also_embed_subwords:
|
||||||
embeddings = [
|
embeddings = [
|
||||||
make_hash_embed(NORM),
|
make_hash_embed(LOWER),
|
||||||
make_hash_embed(PREFIX),
|
make_hash_embed(PREFIX),
|
||||||
make_hash_embed(SUFFIX),
|
make_hash_embed(SUFFIX),
|
||||||
make_hash_embed(SHAPE),
|
make_hash_embed(SHAPE),
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
embeddings = [make_hash_embed(NORM)]
|
embeddings = [make_hash_embed(LOWER)]
|
||||||
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
||||||
if also_use_static_vectors:
|
if also_use_static_vectors:
|
||||||
model = chain(
|
model = chain(
|
||||||
|
@ -165,7 +165,8 @@ def MultiHashEmbed(
|
||||||
|
|
||||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||||
def CharacterEmbed(
|
def CharacterEmbed(
|
||||||
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool
|
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
|
||||||
|
feature: Union[int, str]="LOWER"
|
||||||
):
|
):
|
||||||
"""Construct an embedded representation based on character embeddings, using
|
"""Construct an embedded representation based on character embeddings, using
|
||||||
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
||||||
|
@ -179,12 +180,13 @@ def CharacterEmbed(
|
||||||
of being in an arbitrary position depending on the word length.
|
of being in an arbitrary position depending on the word length.
|
||||||
|
|
||||||
The characters are embedded in a embedding table with a given number of rows,
|
The characters are embedded in a embedding table with a given number of rows,
|
||||||
and the vectors concatenated. A hash-embedded vector of the NORM of the word is
|
and the vectors concatenated. A hash-embedded vector of the LOWER of the word is
|
||||||
also concatenated on, and the result is then passed through a feed-forward
|
also concatenated on, and the result is then passed through a feed-forward
|
||||||
network to construct a single vector to represent the information.
|
network to construct a single vector to represent the information.
|
||||||
|
|
||||||
width (int): The width of the output vector and the NORM hash embedding.
|
feature (int or str): An attribute to embed, to concatenate with the characters.
|
||||||
rows (int): The number of rows in the NORM hash embedding table.
|
width (int): The width of the output vector and the feature embedding.
|
||||||
|
rows (int): The number of rows in the LOWER hash embedding table.
|
||||||
nM (int): The dimensionality of the character embeddings. Recommended values
|
nM (int): The dimensionality of the character embeddings. Recommended values
|
||||||
are between 16 and 64.
|
are between 16 and 64.
|
||||||
nC (int): The number of UTF-8 bytes to embed per word. Recommended values
|
nC (int): The number of UTF-8 bytes to embed per word. Recommended values
|
||||||
|
@ -193,12 +195,15 @@ def CharacterEmbed(
|
||||||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||||
"""
|
"""
|
||||||
|
feature = intify_attr(feature)
|
||||||
|
if feature is None:
|
||||||
|
raise ValueError("Invalid feature: Must be a token attribute.")
|
||||||
if also_use_static_vectors:
|
if also_use_static_vectors:
|
||||||
model = chain(
|
model = chain(
|
||||||
concatenate(
|
concatenate(
|
||||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||||
chain(
|
chain(
|
||||||
FeatureExtractor([NORM]),
|
FeatureExtractor([feature]),
|
||||||
list2ragged(),
|
list2ragged(),
|
||||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||||
),
|
),
|
||||||
|
@ -214,7 +219,7 @@ def CharacterEmbed(
|
||||||
concatenate(
|
concatenate(
|
||||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||||
chain(
|
chain(
|
||||||
FeatureExtractor([NORM]),
|
FeatureExtractor([feature]),
|
||||||
list2ragged(),
|
list2ragged(),
|
||||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||||
),
|
),
|
||||||
|
|
Loading…
Reference in New Issue