From 9614e53b02749e8fec394c0f8a7f965a392918d2 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Mon, 5 Oct 2020 21:55:18 +0200 Subject: [PATCH] Tidy up and auto-format --- spacy/ml/models/tok2vec.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 65d2bffbb..61edb86c4 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Union, Dict +from typing import Optional, List, Union from thinc.types import Floats2d from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed @@ -11,7 +11,7 @@ from ...ml import _character_embed from ..staticvectors import StaticVectors from ..featureextractor import FeatureExtractor from ...pipeline.tok2vec import Tok2VecListener -from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr +from ...attrs import intify_attr @registry.architectures.register("spacy.Tok2VecListener.v1") @@ -29,7 +29,7 @@ def build_hash_embed_cnn_tok2vec( window_size: int, maxout_pieces: int, subword_features: bool, - pretrained_vectors: Optional[bool] + pretrained_vectors: Optional[bool], ) -> Model[List[Doc], List[Floats2d]]: """Build spaCy's 'standard' tok2vec layer, which uses hash embedding with subword features and a CNN with layer-normalized maxout. @@ -56,7 +56,7 @@ def build_hash_embed_cnn_tok2vec( """ if subword_features: attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] - row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2] + row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2] else: attrs = ["NORM"] row_sizes = [embed_size] @@ -120,7 +120,7 @@ def MultiHashEmbed( layer is used to map the vectors to the specified width before concatenating it with the other embedding outputs. A single Maxout layer is then used to reduce the concatenated vectors to the final width. - + The `rows` parameter controls the number of rows used by the `HashEmbed` tables. The HashEmbed layer needs surprisingly few rows, due to its use of the hashing trick. Generally between 2000 and 10000 rows is sufficient, @@ -143,13 +143,7 @@ def MultiHashEmbed( def make_hash_embed(index): nonlocal seed seed += 1 - return HashEmbed( - width, - rows[index], - column=index, - seed=seed, - dropout=0.0, - ) + return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0) embeddings = [make_hash_embed(i) for i in range(len(attrs))] concat_size = width * (len(embeddings) + include_static_vectors)