mirror of https://github.com/explosion/spaCy.git

Update tests

commit f0cf4a2dca
parent 07b47eaac8
@@ -41,7 +41,7 @@ factory = "tagger"
 @architectures = "spacy.Tagger.v1"
 
 [components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
 width = ${components.tok2vec.model:width}
 """
 
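Note: the change above swaps the tagger's internal tok2vec reference to spacy.Tok2VecListener.v1, the architecture that lets a component reuse the output of a shared upstream tok2vec component instead of embedding its own copy. A minimal sketch of how the listener pairs with that upstream component in the same config; the [components.tok2vec] block here is illustrative context and not part of this commit:

[components.tok2vec]
factory = "tok2vec"

# The listener reads its width from the shared model so the shapes line up.
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model:width}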
|
@@ -71,7 +71,7 @@ def my_parser():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=321,
-            embed_size=5432,
+            rows=5432,
             also_embed_subwords=True,
             also_use_static_vectors=False
         ),
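Note: the only change in this hunk is the rename of MultiHashEmbed's embed_size argument to rows. A standalone sketch of the updated call, using only names and parameters that appear in this commit (the width value is copied from the hunk above; the encoder settings are taken from the tests below):

from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder

# Embedding layer: "rows" is the renamed "embed_size".
embed = MultiHashEmbed(
    width=321,
    rows=5432,
    also_embed_subwords=True,
    also_use_static_vectors=False,
)
# Contextual encoder over the embedded tokens.
encode = MaxoutWindowEncoder(width=321, window_size=1, maxout_pieces=3, depth=2)
tok2vec = build_Tok2Vec_model(embed, encode)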
|
|
|
@@ -1,7 +1,8 @@
 import pytest
 
 from spacy.ml.models.tok2vec import build_Tok2Vec_model
-from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder
+from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
+from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
 
|
@@ -60,26 +61,26 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 
 
 # fmt: off
-@pytest.mark.xfail(reason="TODO: Update for new signature")
 @pytest.mark.parametrize(
-    "tok2vec_config",
+    "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
+        (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
 )
 # fmt: on
-def test_tok2vec_configs(tok2vec_config):
+def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
+    embed_config["width"] = width
+    encode_config["width"] = width
     docs = get_batch(3)
-    tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config)
+    tok2vec = build_Tok2Vec_model(
+        embed_arch(**embed_config),
+        encode_arch(**encode_config)
+    )
     tok2vec.initialize(docs)
    vectors, backprop = tok2vec.begin_update(docs)
    assert len(vectors) == len(docs)
-    assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
+    assert vectors[0].shape == (len(docs[0]), width)
    backprop(vectors)
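Note: each parametrized case now pairs an embedding architecture with an encoder architecture, and the test patches the shared width into both keyword dicts before instantiating them. Expanding the first CharacterEmbed case by hand, a sketch of what one pass of the rewritten test does (get_batch is the helper this test file already uses):

docs = get_batch(3)
embed = CharacterEmbed(width=8, rows=100, nM=64, nC=8)
encode = MaxoutWindowEncoder(width=8, window_size=1, maxout_pieces=3, depth=3)
tok2vec = build_Tok2Vec_model(embed, encode)
tok2vec.initialize(docs)
vectors, backprop = tok2vec.begin_update(docs)
# One output array per doc, each of shape (n_tokens, width).
assert len(vectors) == len(docs)
assert vectors[0].shape == (len(docs[0]), 8)
backprop(vectors)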