Put Tok2Vec refactor behind feature flag (#4563)

* Add back pre-2.2.2 tok2vec

* Add simple tok2vec tests

* Add simple tok2vec tests

* Reformat

* Fix CharacterEmbed in new tok2vec

* Fix legacy tok2vec

* Resolve circular imports

* Fix test for Python 2
This commit is contained in:
Matthew Honnibal 2019-10-31 15:01:15 +01:00 committed by Ines Montani
parent 828108a57f
commit e82306937e
4 changed files with 211 additions and 4 deletions

View File

@ -25,9 +25,12 @@ from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
from .errors import Errors, user_warning, Warnings
from . import util
from . import ml as new_ml
from .ml import _legacy_tok2vec
VECTORS_KEY = "spacy_pretrained_vectors"
# Backwards compatibility with <2.2.2
USE_MODEL_REGISTRY_TOK2VEC = False
def cosine(vec1, vec2):
@ -315,6 +318,9 @@ def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
def Tok2Vec(width, embed_size, **kwargs):
if not USE_MODEL_REGISTRY_TOK2VEC:
# Preserve prior tok2vec for backwards compat, in v2.2.2
return _legacy_tok2vec.Tok2Vec(width, embed_size, **kwargs)
pretrained_vectors = kwargs.get("pretrained_vectors", None)
cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
subword_features = kwargs.get("subword_features", True)

131
spacy/ml/_legacy_tok2vec.py Normal file
View File

@ -0,0 +1,131 @@
# coding: utf8
from __future__ import unicode_literals
from thinc.v2v import Model, Maxout
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual
from thinc.misc import LayerNorm as LN
from thinc.misc import FeatureExtracter
from thinc.api import layerize, chain, clone, concatenate, with_flatten
from thinc.api import uniqued, wrap, noop
from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
def Tok2Vec(width, embed_size, **kwargs):
# Circular imports :(
from .._ml import CharacterEmbed
from .._ml import PyTorchBiLSTM
pretrained_vectors = kwargs.get("pretrained_vectors", None)
cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
subword_features = kwargs.get("subword_features", True)
char_embed = kwargs.get("char_embed", False)
if char_embed:
subword_features = False
conv_depth = kwargs.get("conv_depth", 4)
bilstm_depth = kwargs.get("bilstm_depth", 0)
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm")
if subword_features:
prefix = HashEmbed(
width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix"
)
suffix = HashEmbed(
width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix"
)
shape = HashEmbed(
width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape"
)
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))
if subword_features:
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 5, pieces=3)),
column=cols.index(ORTH),
)
else:
embed = uniqued(
(glove | norm) >> LN(Maxout(width, width * 2, pieces=3)),
column=cols.index(ORTH),
)
elif subword_features:
embed = uniqued(
(norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 4, pieces=3)),
column=cols.index(ORTH),
)
elif char_embed:
embed = concatenate_lists(
CharacterEmbed(nM=64, nC=8),
FeatureExtracter(cols) >> with_flatten(norm),
)
reduce_dimensions = LN(
Maxout(width, 64 * 8 + width, pieces=cnn_maxout_pieces)
)
else:
embed = norm
convolution = Residual(
ExtractWindow(nW=1)
>> LN(Maxout(width, width * 3, pieces=cnn_maxout_pieces))
)
if char_embed:
tok2vec = embed >> with_flatten(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtracter(cols) >> with_flatten(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchBiLSTM(width, width, bilstm_depth)
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.nO = width
tok2vec.embed = embed
return tok2vec
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=0)
X = ops.flatten(seqs, pad=0)
return X, finish_update
def concatenate_lists(*layers, **kwargs): # pragma: no cover
"""Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
"""
if not layers:
return noop()
drop_factor = kwargs.get("drop_factor", 1.0)
ops = layers[0].ops
layers = [chain(layer, flatten) for layer in layers]
concat = concatenate(*layers)
def concatenate_lists_fwd(Xs, drop=0.0):
if drop is not None:
drop *= drop_factor
lengths = ops.asarray([len(X) for X in Xs], dtype="i")
flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
ys = ops.unflatten(flat_y, lengths)
def concatenate_lists_bwd(d_ys, sgd=None):
return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
return ys, concatenate_lists_bwd
model = wrap(concatenate_lists_fwd, concat)
return model

View File

@ -6,7 +6,6 @@ from thinc.v2v import Maxout, Model
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual, LayerNorm, FeatureExtracter
from ..util import make_layer, register_architecture
from ._wire import concatenate_lists
@ -72,19 +71,20 @@ def MultiHashEmbed(config):
)
elif config["@pretrained_vectors"]:
mix._layers[0].nI = width * 2
embed = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
else:
embed = norm
layer = norm
layer.cfg = config
return layer
@register_architecture("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
from .. import _ml
width = config["width"]
chars = config["chars"]
chr_embed = CharacterEmbed(nM=width, nC=chars)
chr_embed = _ml.CharacterEmbedModel(nM=width, nC=chars)
other_tables = make_layer(config["@embed_features"])
mix = make_layer(config["@mix"])
@ -128,6 +128,7 @@ def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"])
@register_architecture("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
@ -142,6 +143,9 @@ def TorchBiLSTMEncoder(config):
)
_EXAMPLE_CONFIG = {
"@doc2feats": {
"arch": "Doc2Feats",

View File

@ -0,0 +1,66 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from spacy._ml import Tok2Vec
from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.compat import unicode_
def get_batch(batch_size):
vocab = Vocab()
docs = []
start = 0
for size in range(1, batch_size + 1):
# Make the words numbers, so that they're distnct
# across the batch, and easy to track.
numbers = [unicode_(i) for i in range(start, start + size)]
docs.append(Doc(vocab, words=numbers))
start += size
return docs
# This fails in Thinc v7.3.1. Need to push patch
@pytest.mark.xfail
def test_empty_doc():
width = 128
embed_size = 2000
vocab = Vocab()
doc = Doc(vocab, words=[])
tok2vec = Tok2Vec(width, embed_size)
vectors, backprop = tok2vec.begin_update([doc])
assert len(vectors) == 1
assert vectors[0].shape == (0, width)
@pytest.mark.parametrize(
"batch_size,width,embed_size", [[1, 128, 2000], [2, 128, 2000], [3, 8, 63]]
)
def test_tok2vec_batch_sizes(batch_size, width, embed_size):
batch = get_batch(batch_size)
tok2vec = Tok2Vec(width, embed_size)
vectors, backprop = tok2vec.begin_update(batch)
assert len(vectors) == len(batch)
for doc_vec, doc in zip(vectors, batch):
assert doc_vec.shape == (len(doc), width)
@pytest.mark.parametrize(
"tok2vec_config",
[
{"width": 8, "embed_size": 100, "char_embed": False},
{"width": 8, "embed_size": 100, "char_embed": True},
{"width": 8, "embed_size": 100, "conv_depth": 6},
{"width": 8, "embed_size": 100, "conv_depth": 6},
{"width": 8, "embed_size": 100, "subword_features": False},
],
)
def test_tok2vec_configs(tok2vec_config):
docs = get_batch(3)
tok2vec = Tok2Vec(**tok2vec_config)
vectors, backprop = tok2vec.begin_update(docs)
assert len(vectors) == len(docs)
assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
backprop(vectors)