diff --git a/spacy/_ml.py b/spacy/_ml.py
index 2d4064652..8695a88cc 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -25,9 +25,12 @@ from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
 from .errors import Errors, user_warning, Warnings
 from . import util
 from . import ml as new_ml
+from .ml import _legacy_tok2vec
 
 
 VECTORS_KEY = "spacy_pretrained_vectors"
+# Backwards compatibility with <2.2.2
+USE_MODEL_REGISTRY_TOK2VEC = False
 
 
 def cosine(vec1, vec2):
@@ -315,6 +318,9 @@ def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
 
 
 def Tok2Vec(width, embed_size, **kwargs):
+    if not USE_MODEL_REGISTRY_TOK2VEC:
+        # Preserve the prior tok2vec for backwards compatibility in v2.2.2
+        return _legacy_tok2vec.Tok2Vec(width, embed_size, **kwargs)
     pretrained_vectors = kwargs.get("pretrained_vectors", None)
     cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
     subword_features = kwargs.get("subword_features", True)
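With USE_MODEL_REGISTRY_TOK2VEC left at its default of False, Tok2Vec() above keeps returning the pre-2.2.2 layer; setting the flag to True before building a model opts into the registry-based implementation in spacy/ml/tok2vec.py. A minimal usage sketch of the call pattern, mirroring the tests added at the end of this patch (the width and embed_size values here are arbitrary):

    from spacy import _ml
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    # Default (False) builds the pre-2.2.2 tok2vec; flipping the module-level
    # flag before calling Tok2Vec() would build the registry-based model instead.
    tok2vec = _ml.Tok2Vec(width=96, embed_size=2000)
    doc = Doc(Vocab(), words=["hello", "world"])
    vectors, backprop = tok2vec.begin_update([doc])  # thinc Model API
    assert vectors[0].shape == (2, 96)               # one row per token, `width` columns
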
diff --git a/spacy/ml/_legacy_tok2vec.py b/spacy/ml/_legacy_tok2vec.py
new file mode 100644
index 000000000..b077a46b7
--- /dev/null
+++ b/spacy/ml/_legacy_tok2vec.py
@@ -0,0 +1,131 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from thinc.v2v import Model, Maxout
+from thinc.i2v import HashEmbed, StaticVectors
+from thinc.t2t import ExtractWindow
+from thinc.misc import Residual
+from thinc.misc import LayerNorm as LN
+from thinc.misc import FeatureExtracter
+from thinc.api import layerize, chain, clone, concatenate, with_flatten
+from thinc.api import uniqued, wrap, noop
+
+from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
+
+
+def Tok2Vec(width, embed_size, **kwargs):
+    # Circular imports :(
+    from .._ml import CharacterEmbed
+    from .._ml import PyTorchBiLSTM
+
+    pretrained_vectors = kwargs.get("pretrained_vectors", None)
+    cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
+    subword_features = kwargs.get("subword_features", True)
+    char_embed = kwargs.get("char_embed", False)
+    if char_embed:
+        subword_features = False
+    conv_depth = kwargs.get("conv_depth", 4)
+    bilstm_depth = kwargs.get("bilstm_depth", 0)
+    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
+    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
+        norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm")
+        if subword_features:
+            prefix = HashEmbed(
+                width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix"
+            )
+            suffix = HashEmbed(
+                width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix"
+            )
+            shape = HashEmbed(
+                width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape"
+            )
+        else:
+            prefix, suffix, shape = (None, None, None)
+        if pretrained_vectors is not None:
+            glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))
+
+            if subword_features:
+                embed = uniqued(
+                    (glove | norm | prefix | suffix | shape)
+                    >> LN(Maxout(width, width * 5, pieces=3)),
+                    column=cols.index(ORTH),
+                )
+            else:
+                embed = uniqued(
+                    (glove | norm) >> LN(Maxout(width, width * 2, pieces=3)),
+                    column=cols.index(ORTH),
+                )
+        elif subword_features:
+            embed = uniqued(
+                (norm | prefix | suffix | shape)
+                >> LN(Maxout(width, width * 4, pieces=3)),
+                column=cols.index(ORTH),
+            )
+        elif char_embed:
+            embed = concatenate_lists(
+                CharacterEmbed(nM=64, nC=8),
+                FeatureExtracter(cols) >> with_flatten(norm),
+            )
+            reduce_dimensions = LN(
+                Maxout(width, 64 * 8 + width, pieces=cnn_maxout_pieces)
+            )
+        else:
+            embed = norm
+
+        convolution = Residual(
+            ExtractWindow(nW=1)
+            >> LN(Maxout(width, width * 3, pieces=cnn_maxout_pieces))
+        )
+        if char_embed:
+            tok2vec = embed >> with_flatten(
+                reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
+            )
+        else:
+            tok2vec = FeatureExtracter(cols) >> with_flatten(
+                embed >> convolution ** conv_depth, pad=conv_depth
+            )
+
+        if bilstm_depth >= 1:
+            tok2vec = tok2vec >> PyTorchBiLSTM(width, width, bilstm_depth)
+        # Work around thinc API limitations :(. TODO: Revise in Thinc 7
+        tok2vec.nO = width
+        tok2vec.embed = embed
+    return tok2vec
+
+
+@layerize
+def flatten(seqs, drop=0.0):
+    ops = Model.ops
+    lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
+
+    def finish_update(d_X, sgd=None):
+        return ops.unflatten(d_X, lengths, pad=0)
+
+    X = ops.flatten(seqs, pad=0)
+    return X, finish_update
+
+
+def concatenate_lists(*layers, **kwargs):  # pragma: no cover
+    """Compose two or more models `f`, `g`, etc, such that their outputs are
+    concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
+    """
+    if not layers:
+        return noop()
+    drop_factor = kwargs.get("drop_factor", 1.0)
+    ops = layers[0].ops
+    layers = [chain(layer, flatten) for layer in layers]
+    concat = concatenate(*layers)
+
+    def concatenate_lists_fwd(Xs, drop=0.0):
+        if drop is not None:
+            drop *= drop_factor
+        lengths = ops.asarray([len(X) for X in Xs], dtype="i")
+        flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
+        ys = ops.unflatten(flat_y, lengths)
+
+        def concatenate_lists_bwd(d_ys, sgd=None):
+            return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
+
+        return ys, concatenate_lists_bwd
+
+    model = wrap(concatenate_lists_fwd, concat)
+    return model
diff --git a/spacy/ml/tok2vec.py b/spacy/ml/tok2vec.py
index 4f3cd458d..6180e2185 100644
--- a/spacy/ml/tok2vec.py
+++ b/spacy/ml/tok2vec.py
@@ -6,7 +6,6 @@ from thinc.v2v import Maxout, Model
 from thinc.i2v import HashEmbed, StaticVectors
 from thinc.t2t import ExtractWindow
 from thinc.misc import Residual, LayerNorm, FeatureExtracter
-
 from ..util import make_layer, register_architecture
 from ._wire import concatenate_lists
 
@@ -72,19 +71,20 @@ def MultiHashEmbed(config):
             )
         elif config["@pretrained_vectors"]:
             mix._layers[0].nI = width * 2
-            embed = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
+            layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
         else:
-            embed = norm
+            layer = norm
         layer.cfg = config
         return layer
 
 
 @register_architecture("spacy.CharacterEmbed.v1")
 def CharacterEmbed(config):
+    from .. import _ml
     width = config["width"]
     chars = config["chars"]
-    chr_embed = CharacterEmbed(nM=width, nC=chars)
+    chr_embed = _ml.CharacterEmbedModel(nM=width, nC=chars)
     other_tables = make_layer(config["@embed_features"])
     mix = make_layer(config["@mix"])
@@ -128,6 +128,7 @@ def PretrainedVectors(config):
     return StaticVectors(config["vectors_name"], config["width"], config["column"])
 
 
+
 @register_architecture("spacy.TorchBiLSTMEncoder.v1")
 def TorchBiLSTMEncoder(config):
     import torch.nn
@@ -142,6 +143,9 @@ def TorchBiLSTMEncoder(config):
     )
 
 
+
+
+
 _EXAMPLE_CONFIG = {
     "@doc2feats": {
         "arch": "Doc2Feats",
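For reference, a rough sketch of the registry contract that spacy/ml/tok2vec.py builds on: register_architecture() publishes a constructor under a name, and make_layer() is assumed to look up cfg["arch"] and call the constructor with cfg["config"], as the make_layer() calls and _EXAMPLE_CONFIG above suggest. The architecture name and config keys below are invented for illustration and are not part of spaCy:

    from thinc.v2v import Affine

    from spacy.util import make_layer, register_architecture

    @register_architecture("demo.TinyAffine.v1")
    def TinyAffine(config):
        # Any function that turns a plain config dict into a thinc model will do.
        return Affine(config["nO"], config["nI"])

    layer = make_layer({"arch": "demo.TinyAffine.v1", "config": {"nO": 8, "nI": 4}})
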
diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py
new file mode 100644
index 000000000..ddaa71059
--- /dev/null
+++ b/spacy/tests/test_tok2vec.py
@@ -0,0 +1,66 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+
+from spacy._ml import Tok2Vec
+from spacy.vocab import Vocab
+from spacy.tokens import Doc
+from spacy.compat import unicode_
+
+
+def get_batch(batch_size):
+    vocab = Vocab()
+    docs = []
+    start = 0
+    for size in range(1, batch_size + 1):
+        # Make the words numbers, so that they're distinct
+        # across the batch, and easy to track.
+        numbers = [unicode_(i) for i in range(start, start + size)]
+        docs.append(Doc(vocab, words=numbers))
+        start += size
+    return docs
+
+
+# This fails in Thinc v7.3.1. Need to push patch
+@pytest.mark.xfail
+def test_empty_doc():
+    width = 128
+    embed_size = 2000
+    vocab = Vocab()
+    doc = Doc(vocab, words=[])
+    tok2vec = Tok2Vec(width, embed_size)
+    vectors, backprop = tok2vec.begin_update([doc])
+    assert len(vectors) == 1
+    assert vectors[0].shape == (0, width)
+
+
+@pytest.mark.parametrize(
+    "batch_size,width,embed_size", [[1, 128, 2000], [2, 128, 2000], [3, 8, 63]]
+)
+def test_tok2vec_batch_sizes(batch_size, width, embed_size):
+    batch = get_batch(batch_size)
+    tok2vec = Tok2Vec(width, embed_size)
+    vectors, backprop = tok2vec.begin_update(batch)
+    assert len(vectors) == len(batch)
+    for doc_vec, doc in zip(vectors, batch):
+        assert doc_vec.shape == (len(doc), width)
+
+
+@pytest.mark.parametrize(
+    "tok2vec_config",
+    [
+        {"width": 8, "embed_size": 100, "char_embed": False},
+        {"width": 8, "embed_size": 100, "char_embed": True},
+        {"width": 8, "embed_size": 100, "conv_depth": 6},
+        {"width": 8, "embed_size": 100, "conv_depth": 6},
+        {"width": 8, "embed_size": 100, "subword_features": False},
+    ],
+)
+def test_tok2vec_configs(tok2vec_config):
+    docs = get_batch(3)
+    tok2vec = Tok2Vec(**tok2vec_config)
+    vectors, backprop = tok2vec.begin_update(docs)
+    assert len(vectors) == len(docs)
+    assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
+    backprop(vectors)
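A possible follow-up test, not part of this patch, that would exercise both code paths through the new module-level flag. It reuses the get_batch() helper defined above and assumes that toggling USE_MODEL_REGISTRY_TOK2VEC via monkeypatch is safe at test time:

    from spacy import _ml

    @pytest.mark.parametrize("use_registry", [False, True])
    def test_tok2vec_registry_toggle(monkeypatch, use_registry):
        monkeypatch.setattr(_ml, "USE_MODEL_REGISTRY_TOK2VEC", use_registry)
        docs = get_batch(2)
        tok2vec = _ml.Tok2Vec(8, 100)
        vectors, backprop = tok2vec.begin_update(docs)
        assert [v.shape for v in vectors] == [(len(doc), 8) for doc in docs]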