diff --git a/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg b/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg index 4f1a915c5..b6b4e82b6 100644 --- a/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg +++ b/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg @@ -62,4 +62,4 @@ width = 96 depth = 4 embed_size = 2000 subword_features = true -char_embed = false +maxout_pieces = 3 diff --git a/examples/experiments/tok2vec-ner/charembed_tok2vec.cfg b/examples/experiments/tok2vec-ner/charembed_tok2vec.cfg new file mode 100644 index 000000000..b8219ad10 --- /dev/null +++ b/examples/experiments/tok2vec-ner/charembed_tok2vec.cfg @@ -0,0 +1,65 @@ +[training] +use_gpu = -1 +limit = 0 +dropout = 0.2 +patience = 10000 +eval_frequency = 200 +scores = ["ents_f"] +score_weights = {"ents_f": 1} +orth_variant_level = 0.0 +gold_preproc = true +max_length = 0 +batch_size = 25 + +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.001 +beta1 = 0.9 +beta2 = 0.999 + +[nlp] +lang = "en" +vectors = null + +[nlp.pipeline.tok2vec] +factory = "tok2vec" + +[nlp.pipeline.tok2vec.model] +@architectures = "spacy.Tok2Vec.v1" + +[nlp.pipeline.tok2vec.model.extract] +@architectures = "spacy.CharacterEmbed.v1" +width = 96 +nM = 64 +nC = 8 +rows = 2000 +columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"] + +[nlp.pipeline.tok2vec.model.extract.features] +@architectures = "spacy.Doc2Feats.v1" +columns = ${nlp.pipeline.tok2vec.model.extract:columns} + +[nlp.pipeline.tok2vec.model.embed] +@architectures = "spacy.LayerNormalizedMaxout.v1" +width = ${nlp.pipeline.tok2vec.model.extract:width} +maxout_pieces = 4 + +[nlp.pipeline.tok2vec.model.encode] +@architectures = "spacy.MaxoutWindowEncoder.v1" +width = ${nlp.pipeline.tok2vec.model.extract:width} +window_size = 1 +maxout_pieces = 2 +depth = 2 + +[nlp.pipeline.ner] +factory = "ner" + +[nlp.pipeline.ner.model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 6 +hidden_width = 64 +maxout_pieces = 2 + +[nlp.pipeline.ner.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model.extract:width} diff --git a/examples/experiments/tok2vec-ner/multihashembed_tok2vec.cfg b/examples/experiments/tok2vec-ner/multihashembed_tok2vec.cfg new file mode 100644 index 000000000..4678a7d6b --- /dev/null +++ b/examples/experiments/tok2vec-ner/multihashembed_tok2vec.cfg @@ -0,0 +1,65 @@ +[training] +use_gpu = -1 +limit = 0 +dropout = 0.2 +patience = 10000 +eval_frequency = 200 +scores = ["ents_f"] +score_weights = {"ents_f": 1} +orth_variant_level = 0.0 +gold_preproc = true +max_length = 0 +batch_size = 25 + +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.001 +beta1 = 0.9 +beta2 = 0.999 + +[nlp] +lang = "en" +vectors = null + +[nlp.pipeline.tok2vec] +factory = "tok2vec" + +[nlp.pipeline.tok2vec.model] +@architectures = "spacy.Tok2Vec.v1" + +[nlp.pipeline.tok2vec.model.extract] +@architectures = "spacy.Doc2Feats.v1" +columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"] + +[nlp.pipeline.tok2vec.model.embed] +@architectures = "spacy.MultiHashEmbed.v1" +columns = ${nlp.pipeline.tok2vec.model.extract:columns} +width = 96 +rows = 2000 +use_subwords = true +pretrained_vectors = null + +[nlp.pipeline.tok2vec.model.embed.mix] +@architectures = "spacy.LayerNormalizedMaxout.v1" +width = ${nlp.pipeline.tok2vec.model.embed:width} +maxout_pieces = 3 + +[nlp.pipeline.tok2vec.model.encode] +@architectures = "spacy.MaxoutWindowEncoder.v1" +width = ${nlp.pipeline.tok2vec.model.embed:width} +window_size = 1 +maxout_pieces = 3 +depth = 2 + +[nlp.pipeline.ner] +factory = "ner" + +[nlp.pipeline.ner.model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 6 +hidden_width = 64 +maxout_pieces = 2 + +[nlp.pipeline.ner.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model.embed:width} diff --git a/spacy/language.py b/spacy/language.py index d0077b9d2..20e29c829 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -337,13 +337,14 @@ class Language(object): default_config = self.defaults.get(name, None) # transform the model's config to an actual Model + factory_cfg = dict(config) model_cfg = None - if "model" in config: - model_cfg = config["model"] + if "model" in factory_cfg: + model_cfg = factory_cfg["model"] if not isinstance(model_cfg, dict): warnings.warn(Warnings.W099.format(type=type(model_cfg), pipe=name)) model_cfg = None - del config["model"] + del factory_cfg["model"] if model_cfg is None and default_config is not None: warnings.warn(Warnings.W098.format(name=name)) model_cfg = default_config["model"] @@ -353,7 +354,7 @@ class Language(object): model = registry.make_from_config({"model": model_cfg}, validate=True)[ "model" ] - return factory(self, model, **config) + return factory(self, model, **factory_cfg) def add_pipe( self, component, name=None, before=None, after=None, first=None, last=None diff --git a/spacy/ml/_character_embed.py b/spacy/ml/_character_embed.py index b366f67c6..f4890144a 100644 --- a/spacy/ml/_character_embed.py +++ b/spacy/ml/_character_embed.py @@ -21,7 +21,7 @@ def init(model, X=None, Y=None): def forward(model, docs, is_train): - if not docs: + if docs is None: return [] ids = [] output = [] diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 0d33d010d..d1a98c080 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -4,7 +4,7 @@ from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM from thinc.api import residual, LayerNorm, FeatureExtractor, Mish from ... import util -from ...util import registry, make_layer +from ...util import registry from ...ml import _character_embed from ...pipeline.tok2vec import Tok2VecListener from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE @@ -23,15 +23,14 @@ def get_vocab_vectors(name): @registry.architectures.register("spacy.Tok2Vec.v1") -def Tok2Vec(config): - doc2feats = make_layer(config["@doc2feats"]) - embed = make_layer(config["@embed"]) - encode = make_layer(config["@encode"]) +def Tok2Vec(extract, embed, encode): field_size = 0 - if encode.has_attr("receptive_field"): + if encode.attrs.get("receptive_field", None): field_size = encode.attrs["receptive_field"] - tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size)) - tok2vec.attrs["cfg"] = config + with Model.define_operators({">>": chain, "|": concatenate}): + if extract.has_dim("nO"): + _set_dims(embed, "nI", extract.get_dim("nO")) + tok2vec = extract >> with_array(embed >> encode, pad=field_size) tok2vec.set_dim("nO", encode.get_dim("nO")) tok2vec.set_ref("embed", embed) tok2vec.set_ref("encode", encode) @@ -39,8 +38,7 @@ def Tok2Vec(config): @registry.architectures.register("spacy.Doc2Feats.v1") -def Doc2Feats(config): - columns = config["columns"] +def Doc2Feats(columns): return FeatureExtractor(columns) @@ -79,8 +77,8 @@ def hash_charembed_cnn( maxout_pieces, window_size, subword_features, - nM=0, - nC=0, + nM, + nC, ): # Allows using character embeddings by setting nC, nM and char_embed=True return build_Tok2Vec_model( @@ -100,7 +98,7 @@ def hash_charembed_cnn( @registry.architectures.register("spacy.HashEmbedBiLSTM.v1") def hash_embed_bilstm_v1( - pretrained_vectors, width, depth, embed_size, subword_features + pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces ): # Does not use character embeddings: set to False by default return build_Tok2Vec_model( @@ -109,7 +107,7 @@ def hash_embed_bilstm_v1( pretrained_vectors=pretrained_vectors, bilstm_depth=depth, conv_depth=0, - maxout_pieces=0, + maxout_pieces=maxout_pieces, window_size=1, subword_features=subword_features, char_embed=False, @@ -120,7 +118,7 @@ def hash_embed_bilstm_v1( @registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1") def hash_char_embed_bilstm_v1( - pretrained_vectors, width, depth, embed_size, subword_features, nM=0, nC=0 + pretrained_vectors, width, depth, embed_size, subword_features, nM, nC, maxout_pieces ): # Allows using character embeddings by setting nC, nM and char_embed=True return build_Tok2Vec_model( @@ -129,7 +127,7 @@ def hash_char_embed_bilstm_v1( pretrained_vectors=pretrained_vectors, bilstm_depth=depth, conv_depth=0, - maxout_pieces=0, + maxout_pieces=maxout_pieces, window_size=1, subword_features=subword_features, char_embed=True, @@ -138,104 +136,99 @@ def hash_char_embed_bilstm_v1( ) -@registry.architectures.register("spacy.MultiHashEmbed.v1") -def MultiHashEmbed(config): - # For backwards compatibility with models before the architecture registry, - # we have to be careful to get exactly the same model structure. One subtle - # trick is that when we define concatenation with the operator, the operator - # is actually binary associative. So when we write (a | b | c), we're actually - # getting concatenate(concatenate(a, b), c). That's why the implementation - # is a bit ugly here. - cols = config["columns"] - width = config["width"] - rows = config["rows"] +@registry.architectures.register("spacy.LayerNormalizedMaxout.v1") +def LayerNormalizedMaxout(width, maxout_pieces): + return Maxout( + nO=width, + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ) - norm = HashEmbed(width, rows, column=cols.index("NORM")) - if config["use_subwords"]: - prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX")) - suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX")) - shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE")) - if config.get("@pretrained_vectors"): - glove = make_layer(config["@pretrained_vectors"]) - mix = make_layer(config["@mix"]) + +@registry.architectures.register("spacy.MultiHashEmbed.v1") +def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix): + norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM")) + if use_subwords: + prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX")) + suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX")) + shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE")) + + if pretrained_vectors: + glove = StaticVectors( + vectors=pretrained_vectors.data, + nO=width, + column=columns.index(ID), + dropout=0.0, + ) with Model.define_operators({">>": chain, "|": concatenate}): - if config["use_subwords"] and config["@pretrained_vectors"]: - mix._layers[0].set_dim("nI", width * 5) - layer = uniqued( - (glove | norm | prefix | suffix | shape) >> mix, - column=cols.index("ORTH"), - ) - elif config["use_subwords"]: - mix._layers[0].set_dim("nI", width * 4) - layer = uniqued( - (norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH") - ) - elif config["@pretrained_vectors"]: - mix._layers[0].set_dim("nI", width * 2) - layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH")) + if not use_subwords and not pretrained_vectors: + embed_layer = norm else: - layer = norm - layer.attrs["cfg"] = config - return layer + if use_subwords and pretrained_vectors: + nr_columns = 5 + concat_columns = glove | norm | prefix | suffix | shape + elif use_subwords: + nr_columns = 4 + concat_columns = norm | prefix | suffix | shape + else: + nr_columns = 2 + concat_columns = glove | norm + _set_dims(mix, "nI", width * nr_columns) + embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH")) + + return embed_layer + + +def _set_dims(model, name, value): + # Loop through the model to set a specific dimension if its unset on any layer. + for node in model.walk(): + if node.has_dim(name) is None: + node.set_dim(name, value) @registry.architectures.register("spacy.CharacterEmbed.v1") -def CharacterEmbed(config): - width = config["width"] - chars = config["chars"] - - chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars) - other_tables = make_layer(config["@embed_features"]) - mix = make_layer(config["@mix"]) - - model = chain(concatenate(chr_embed, other_tables), mix) - model.attrs["cfg"] = config - return model +def CharacterEmbed(columns, width, rows, nM, nC, features): + norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM")) + chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) + with Model.define_operators({">>": chain, "|": concatenate}): + embed_layer = chr_embed | features >> with_array(norm) + embed_layer.set_dim("nO", nM * nC + width) + return embed_layer @registry.architectures.register("spacy.MaxoutWindowEncoder.v1") -def MaxoutWindowEncoder(config): - nO = config["width"] - nW = config["window_size"] - nP = config["pieces"] - depth = config["depth"] - - cnn = ( - expand_window(window_size=nW), - Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True), +def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth): + cnn = chain( + expand_window(window_size=window_size), + Maxout(nO=width, nI=width * ((window_size * 2) + 1), nP=maxout_pieces, dropout=0.0, normalize=True), ) model = clone(residual(cnn), depth) - model.set_dim("nO", nO) - model.attrs["receptive_field"] = nW * depth + model.set_dim("nO", width) + model.attrs["receptive_field"] = window_size * depth return model @registry.architectures.register("spacy.MishWindowEncoder.v1") -def MishWindowEncoder(config): - nO = config["width"] - nW = config["window_size"] - depth = config["depth"] - +def MishWindowEncoder(width, window_size, depth): cnn = chain( - expand_window(window_size=nW), - Mish(nO=nO, nI=nO * ((nW * 2) + 1)), - LayerNorm(nO), + expand_window(window_size=window_size), + Mish(nO=width, nI=width * ((window_size * 2) + 1)), + LayerNorm(width), ) model = clone(residual(cnn), depth) - model.set_dim("nO", nO) + model.set_dim("nO", width) return model @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") -def TorchBiLSTMEncoder(config): +def TorchBiLSTMEncoder(width, depth): import torch.nn # TODO FIX from thinc.api import PyTorchRNNWrapper - width = config["width"] - depth = config["depth"] if depth == 0: return noop() return with_padded( @@ -243,40 +236,6 @@ def TorchBiLSTMEncoder(config): ) -# TODO: update -_EXAMPLE_CONFIG = { - "@doc2feats": { - "arch": "Doc2Feats", - "config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]}, - }, - "@embed": { - "arch": "spacy.MultiHashEmbed.v1", - "config": { - "width": 96, - "rows": 2000, - "columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"], - "use_subwords": True, - "@pretrained_vectors": { - "arch": "TransformedStaticVectors", - "config": { - "vectors_name": "en_vectors_web_lg.vectors", - "width": 96, - "column": 0, - }, - }, - "@mix": { - "arch": "LayerNormalizedMaxout", - "config": {"width": 96, "pieces": 3}, - }, - }, - }, - "@encode": { - "arch": "MaxoutWindowEncode", - "config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3}, - }, -} - - def build_Tok2Vec_model( width, embed_size, diff --git a/spacy/ml/tok2vec.py b/spacy/ml/tok2vec.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 2fee6881a..4623f99b0 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -131,9 +131,10 @@ class Tok2Vec(Pipe): get_examples (function): Function returning example training data. pipeline (list): The pipeline the model is part of. """ - # TODO: use examples instead ? - docs = [Doc(Vocab(), words=["hello"])] - self.model.initialize(X=docs) + # TODO: charembed does not play nicely with dim inference yet + # docs = [Doc(Vocab(), words=["hello"])] + # self.model.initialize(X=docs) + self.model.initialize() link_vectors_to_models(self.vocab) diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py index 7a929a6a2..411768e5f 100644 --- a/spacy/tests/pipeline/test_senter.py +++ b/spacy/tests/pipeline/test_senter.py @@ -36,17 +36,17 @@ def test_overfitting_IO(): assert losses["senter"] < 0.0001 # test the trained model - test_text = "I like eggs. There is ham. She likes ham." + test_text = "I like purple eggs. They eat ham. You like yellow eggs." doc = nlp(test_text) - gold_sent_starts = [0] * 12 + gold_sent_starts = [0] * 14 gold_sent_starts[0] = 1 - gold_sent_starts[4] = 1 - gold_sent_starts[8] = 1 - assert gold_sent_starts == [int(t.is_sent_start) for t in doc] + gold_sent_starts[5] = 1 + gold_sent_starts[9] = 1 + assert [int(t.is_sent_start) for t in doc] == gold_sent_starts # Also test the results are still the same after IO with make_tempdir() as tmp_dir: nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) doc2 = nlp2(test_text) - assert gold_sent_starts == [int(t.is_sent_start) for t in doc2] + assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts diff --git a/spacy/util.py b/spacy/util.py index 216158e52..37649c5e6 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -79,11 +79,6 @@ def set_lang_class(name, cls): registry.languages.register(name, func=cls) -def make_layer(arch_config): - arch_func = registry.architectures.get(arch_config["arch"]) - return arch_func(arch_config["config"]) - - def ensure_path(path): """Ensure string is converted to a Path. @@ -563,7 +558,7 @@ def minibatch_by_words(examples, size, tuples=True, count_words=len): """Create minibatches of a given number of words.""" if isinstance(size, int): size_ = itertools.repeat(size) - if isinstance(size, List): + elif isinstance(size, List): size_ = iter(size) else: size_ = size