diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index 4c4bd0d22..448f9d1d0 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -1,13 +1,15 @@
 from typing import Optional, List
 from thinc.api import chain, clone, concatenate, with_array, with_padded
-from thinc.api import Model, noop
-from thinc.api import FeatureExtractor, HashEmbed, StaticVectors
-from thinc.api import expand_window, residual, Maxout, Mish
+from thinc.api import Model, noop, list2ragged, ragged2list
+from thinc.api import FeatureExtractor, HashEmbed
+from thinc.api import expand_window, residual, Maxout, Mish
 from thinc.types import Floats2d
 
+from ...tokens import Doc
 from ... import util
 from ...util import registry
 from ...ml import _character_embed
+from ..staticvectors import StaticVectors
 from ...pipeline.tok2vec import Tok2VecListener
 from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
 
@@ -21,20 +23,19 @@ def tok2vec_listener_v1(width, upstream="*"):
 @registry.architectures.register("spacy.Tok2Vec.v1")
 def Tok2Vec(
     embed: Model[List[Doc], List[Floats2d]],
-    encode: Model[List[Floats2d], List[Floats2d]
+    encode: Model[List[Floats2d], List[Floats2d]]
 ) -> Model[List[Doc], List[Floats2d]]:
-    tok2vec = with_array(
-        chain(embed, encode),
-        pad=encode.attrs.get("receptive_field", 0)
-    )
+
+    receptive_field = encode.attrs.get("receptive_field", 0)
+    tok2vec = chain(embed, with_array(encode, pad=receptive_field))
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
     tok2vec.set_ref("encode", encode)
     return tok2vec
 
 
-@registry.architectures.register("spacy.HashEmbed.v1")
-def HashEmbed(
+@registry.architectures.register("spacy.MultiHashEmbed.v1")
+def MultiHashEmbed(
     width: int,
     rows: int,
     also_embed_subwords: bool,
@@ -56,9 +57,9 @@ def HashEmbed(
 
     if also_embed_subwords:
         embeddings = [
-            make_hash_embed(NORM)
-            make_hash_embed(PREFIX)
-            make_hash_embed(SUFFIX)
+            make_hash_embed(NORM),
+            make_hash_embed(PREFIX),
+            make_hash_embed(SUFFIX),
             make_hash_embed(SHAPE)
         ]
     else:
@@ -67,15 +68,25 @@ def HashEmbed(
 
     if also_use_static_vectors:
         model = chain(
             concatenate(
-                chain(FeatureExtractor(cols), concatenate(*embeddings)),
+                chain(
+                    FeatureExtractor(cols),
+                    list2ragged(),
+                    with_array(concatenate(*embeddings))
+                ),
                 StaticVectors(width, dropout=0.0)
             ),
-            Maxout(width, pieces=3, dropout=0.0, normalize=True)
+            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            ragged2list()
         )
     else:
         model = chain(
-            chain(FeatureExtractor(cols), concatenate(*embeddings)),
-            Maxout(width, pieces=3, dropout=0.0, normalize=True)
+            chain(
+                FeatureExtractor(cols),
+                list2ragged(),
+                with_array(concatenate(*embeddings))
+            ),
+            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            ragged2list()
         )
     return model
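
For context, a minimal usage sketch (not part of the patch): it builds the embed and encode pieces in Python and chains them with the Tok2Vec factory changed above. The MaxoutWindowEncoder import and its parameters, as well as the width/rows values, are assumptions for illustration, not part of this diff.

# Sketch only: assumes MultiHashEmbed, MaxoutWindowEncoder and Tok2Vec are all
# importable from spacy.ml.models.tok2vec as registered architectures.
from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder, Tok2Vec

# Hash-embed NORM/PREFIX/SUFFIX/SHAPE features, without static vectors.
embed = MultiHashEmbed(
    width=96, rows=2000, also_embed_subwords=True, also_use_static_vectors=False
)
# Contextual encoder; its "receptive_field" attr feeds the pad= argument in Tok2Vec.
encode = MaxoutWindowEncoder(width=96, window_size=1, maxout_pieces=3, depth=4)
tok2vec = Tok2Vec(embed, encode)  # Model[List[Doc], List[Floats2d]]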