Update tests

Matthew Honnibal 2020-07-29 13:47:37 +02:00
parent 07b47eaac8
commit f0cf4a2dca
2 changed files with 17 additions and 16 deletions

File 1 of 2

@@ -41,7 +41,7 @@ factory = "tagger"
 @architectures = "spacy.Tagger.v1"
 [components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
 width = ${components.tok2vec.model:width}
 """
@@ -71,7 +71,7 @@ def my_parser():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=321,
-            embed_size=5432,
+            rows=5432,
             also_embed_subwords=True,
             also_use_static_vectors=False
         ),
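
The two changes in this file track the same API shift: build_Tok2Vec_model now composes an embedding sub-network with an encoding sub-network, and MultiHashEmbed takes rows (the number of hash-embedding table rows) where it previously took embed_size. A minimal sketch of the full call, reusing the arguments above; the MaxoutWindowEncoder settings are borrowed from the test file below, not from this hunk.

from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder

# Embedding and encoding are separate sub-networks, composed by
# build_Tok2Vec_model instead of being configured through one flat kwarg set.
tok2vec = build_Tok2Vec_model(
    MultiHashEmbed(
        width=321,                  # output width of the embedding layer
        rows=5432,                  # hash-table rows; was embed_size
        also_embed_subwords=True,   # add prefix/suffix/shape features
        also_use_static_vectors=False,
    ),
    MaxoutWindowEncoder(width=321, window_size=1, maxout_pieces=3, depth=2),
)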

File 2 of 2

@@ -1,7 +1,8 @@
 import pytest
 from spacy.ml.models.tok2vec import build_Tok2Vec_model
-from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder
+from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
+from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
@@ -60,26 +61,26 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 # fmt: off
-@pytest.mark.xfail(reason="TODO: Update for new signature")
 @pytest.mark.parametrize(
-    "tok2vec_config",
+    "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
+        (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
 )
 # fmt: on
-def test_tok2vec_configs(tok2vec_config):
+def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
+    embed_config["width"] = width
+    encode_config["width"] = width
     docs = get_batch(3)
-    tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config)
+    tok2vec = build_Tok2Vec_model(
+        embed_arch(**embed_config),
+        encode_arch(**encode_config)
+    )
     tok2vec.initialize(docs)
     vectors, backprop = tok2vec.begin_update(docs)
     assert len(vectors) == len(docs)
-    assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
+    assert vectors[0].shape == (len(docs[0]), width)
     backprop(vectors)
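
As a usage note, the updated test drives the composed model like any Thinc model: initialize with sample docs, then run the forward and backward passes. A standalone sketch, assuming Doc and Vocab construction in place of the file's get_batch helper and reusing the CharacterEmbed/MishWindowEncoder settings from the cases above:

from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import CharacterEmbed, MishWindowEncoder
from spacy.tokens import Doc
from spacy.vocab import Vocab

width = 8
tok2vec = build_Tok2Vec_model(
    CharacterEmbed(width=width, rows=100, nM=16, nC=2),
    MishWindowEncoder(width=width, window_size=1, depth=3),
)

vocab = Vocab()
docs = [Doc(vocab, words=["hello", "world"]), Doc(vocab, words=["spaCy"])]
tok2vec.initialize(docs)                # infer shapes from the sample batch
vectors, backprop = tok2vec.begin_update(docs)
assert len(vectors) == len(docs)
assert vectors[0].shape == (len(docs[0]), width)  # one row per token
backprop(vectors)                       # gradients flow back through both sub-networks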