Fix tok2vec

Matthew Honnibal 2020-07-28 15:51:40 +02:00
parent fe0cdcd461
commit 099e9331c5
1 changed file with 28 additions and 17 deletions

spacy/ml/models/tok2vec.py

@@ -1,13 +1,15 @@
 from typing import Optional, List
 from thinc.api import chain, clone, concatenate, with_array, with_padded
-from thinc.api import Model, noop
-from thinc.api import FeatureExtractor, HashEmbed, StaticVectors
+from thinc.api import Model, noop, list2ragged, ragged2list
+from thinc.api import FeatureExtractor, HashEmbed
 from thinc.api import expand_window, residual, Maxout, Mish
 from thinc.types import Floats2d
+from ...tokens import Doc
 from ... import util
 from ...util import registry
 from ...ml import _character_embed
+from ..staticvectors import StaticVectors
 from ...pipeline.tok2vec import Tok2VecListener
 from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
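The import changes track the refactor below: list2ragged and ragged2list come in because the embedding pipeline now runs over one concatenated Ragged batch instead of a list of per-doc arrays, and StaticVectors is now spaCy's own layer (the relative import resolves to spacy/ml/staticvectors.py) rather than Thinc's. A minimal sketch, not from this commit, of the list/ragged round-trip those two new imports provide, with toy shapes and plain numpy arrays standing in for per-token features:

import numpy
from thinc.api import chain, list2ragged, ragged2list

# list2ragged concatenates a list of (n_tokens, width) arrays into one
# contiguous Ragged batch; ragged2list splits it back into per-doc arrays.
roundtrip = chain(list2ragged(), ragged2list())
roundtrip.initialize()
docs = [numpy.ones((3, 5), dtype="f"), numpy.ones((2, 5), dtype="f")]
ys = roundtrip.predict(docs)
assert [y.shape for y in ys] == [(3, 5), (2, 5)]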
@@ -21,20 +23,19 @@ def tok2vec_listener_v1(width, upstream="*"):
 @registry.architectures.register("spacy.Tok2Vec.v1")
 def Tok2Vec(
     embed: Model[List[Doc], List[Floats2d]],
-    encode: Model[List[Floats2d], List[Floats2d]
+    encode: Model[List[Floats2d], List[Floats2d]]
 ) -> Model[List[Doc], List[Floats2d]]:
-    tok2vec = with_array(
-        chain(embed, encode),
-        pad=encode.attrs.get("receptive_field", 0)
-    )
+    receptive_field = encode.attrs.get("receptive_field", 0)
+    tok2vec = chain(embed, with_array(encode, pad=receptive_field))
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
     tok2vec.set_ref("encode", encode)
     return tok2vec
 
 
-@registry.architectures.register("spacy.HashEmbed.v1")
-def HashEmbed(
+@registry.architectures.register("spacy.MultiHashEmbed.v1")
+def MultiHashEmbed(
     width: int,
     rows: int,
     also_embed_subwords: bool,
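This hunk is the fix the commit title refers to. The old code wrapped the whole chain(embed, encode) in with_array, but embed consumes List[Doc], which with_array cannot convert to an array; only the array-to-array encoder can be lifted that way. The new code chains embed with with_array(encode, pad=receptive_field), so Thinc inserts that many zero rows between the concatenated sequences and the encoder's convolution windows do not bleed across document boundaries. Hoisting receptive_field into a local is cosmetic; what with_array wraps is the substantive change. A rough sketch of the corrected composition, using a stand-in encoder rather than spaCy's real MaxoutWindowEncoder:

from thinc.api import Maxout, Model, chain, expand_window, with_array

def toy_encoder(width: int = 8) -> Model:
    # expand_window concatenates each row with window_size neighbours on
    # either side, which is why the encoder advertises a receptive_field
    # for the tok2vec builder to read back as the pad amount.
    encode = chain(expand_window(window_size=1), Maxout(width, width * 3))
    encode.attrs["receptive_field"] = 1
    return encode

def toy_tok2vec(embed: Model, encode: Model) -> Model:
    pad = encode.attrs.get("receptive_field", 0)
    return chain(embed, with_array(encode, pad=pad))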
@@ -56,9 +57,9 @@ def HashEmbed(
     if also_embed_subwords:
         embeddings = [
-            make_hash_embed(NORM)
-            make_hash_embed(PREFIX)
-            make_hash_embed(SUFFIX)
+            make_hash_embed(NORM),
+            make_hash_embed(PREFIX),
+            make_hash_embed(SUFFIX),
             make_hash_embed(SHAPE)
         ]
     else:
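Small but essential: as the removed lines show, the embeddings list had no separating commas, and two adjacent calls in a list literal are a SyntaxError in Python, so the module could not even be imported before this fix.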
@@ -67,15 +68,25 @@ def HashEmbed(
     if also_use_static_vectors:
         model = chain(
             concatenate(
-                chain(FeatureExtractor(cols), concatenate(*embeddings)),
+                chain(
+                    FeatureExtractor(cols),
+                    list2ragged(),
+                    with_array(concatenate(*embeddings))
+                ),
                 StaticVectors(width, dropout=0.0)
             ),
-            Maxout(width, pieces=3, dropout=0.0, normalize=True)
+            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            ragged2list()
         )
     else:
         model = chain(
-            chain(FeatureExtractor(cols), concatenate(*embeddings)),
-            Maxout(width, pieces=3, dropout=0.0, normalize=True)
+            chain(
+                FeatureExtractor(cols),
+                list2ragged(),
+                with_array(concatenate(*embeddings))
+            ),
+            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            ragged2list()
         )
     return model
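Both branches now share the same shape: FeatureExtractor(cols) emits one uint64 feature array per Doc, list2ragged() concatenates those into a single Ragged batch, the hash embeddings (plus, in the first branch, the static vectors) and the Maxout mixing layer run once over the contiguous data inside with_array, and ragged2list() splits the result back per document. The hunk also corrects the Maxout keyword: Thinc's Maxout takes the number of pieces as nP, so pieces=3 would raise a TypeError. A runnable toy version of the else branch, with made-up dimensions and hand-built integer arrays standing in for FeatureExtractor output:

import numpy
from thinc.api import HashEmbed, Maxout, chain, concatenate
from thinc.api import list2ragged, ragged2list, with_array

width, rows = 8, 500
# Two hash-embedding columns stand in for the NORM/PREFIX/SUFFIX/SHAPE set.
embeddings = [
    HashEmbed(nO=width, nV=rows, column=0, seed=1, dropout=0.0),
    HashEmbed(nO=width, nV=rows, column=1, seed=2, dropout=0.0),
]
model = chain(
    list2ragged(),
    with_array(concatenate(*embeddings)),
    with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
    ragged2list(),
)
ids = [numpy.zeros((4, 2), dtype="uint64"), numpy.zeros((3, 2), dtype="uint64")]
model.initialize(X=ids)
ys = model.predict(ids)
assert [y.shape for y in ys] == [(4, width), (3, width)]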