From f3db3f6fe00455f69bf05135f941ba88d307738b Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Wed, 16 Sep 2020 17:45:04 +0200 Subject: [PATCH] Add vectors option to CharacterEmbed (#6069) * Add vectors option to CharacterEmbed * Update spacy/pipeline/morphologizer.pyx * Adjust default morphologizer config Co-authored-by: Matthew Honnibal --- spacy/ml/models/tok2vec.py | 39 +++++++++++++++++++++++--------- spacy/pipeline/morphologizer.pyx | 1 + spacy/tests/test_tok2vec.py | 4 ++-- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 2e5f8a802..7ced4bd04 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -164,7 +164,7 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") -def CharacterEmbed(width: int, rows: int, nM: int, nC: int): +def CharacterEmbed(width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool): """Construct an embedded representation based on character embeddings, using a feed-forward network. A fixed number of UTF-8 byte characters are used for each word, taken from the beginning and end of the word equally. Padding is @@ -188,18 +188,35 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int): nC (int): The number of UTF-8 bytes to embed per word. Recommended values are between 3 and 8, although it may depend on the length of words in the language. + also_use_static_vectors (bool): Whether to also use static word vectors. + Requires a vectors table to be loaded in the Doc objects' vocab. """ - model = chain( - concatenate( - chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), - chain( - FeatureExtractor([NORM]), - list2ragged(), - with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), + if also_use_static_vectors: + model = chain( + concatenate( + chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), + chain( + FeatureExtractor([NORM]), + list2ragged(), + with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), + ), + StaticVectors(width, dropout=0.0), ), - ), - with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)), - ragged2list(), + with_array(Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)), + ragged2list(), + ) + else: + model = chain( + concatenate( + chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), + chain( + FeatureExtractor([NORM]), + list2ragged(), + with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), + ), + ), + with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)), + ragged2list(), ) return model diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 0e0791004..bb68a358c 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -32,6 +32,7 @@ width = 128 rows = 7000 nM = 64 nC = 8 +also_use_static_vectors = false [model.tok2vec.encode] @architectures = "spacy.MaxoutWindowEncoder.v1" diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py index fb30c6ae5..f3f35e4a7 100644 --- a/spacy/tests/test_tok2vec.py +++ b/spacy/tests/test_tok2vec.py @@ -63,8 +63,8 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): [ (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}), (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}), - (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}), - (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}), + (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}), + (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}), ], ) # fmt: on