From 9c064e6ad95ac2733fc27a874644b6ad8caecedf Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Wed, 12 Jun 2019 04:29:23 -0700 Subject: [PATCH] Add resume logic to spacy pretrain (#3652) * Added ability to resume training * Add to readmee * Remove duplicate entry --- spacy/cli/pretrain.py | 15 +++++++++++++++ website/docs/api/cli.md | 1 + 2 files changed, 16 insertions(+) diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 3b2981f0d..18fe0598f 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -18,6 +18,7 @@ from ..attrs import ID, HEAD from .._ml import Tok2Vec, flatten, chain, create_default_optimizer from .._ml import masked_language_model from .. import util +from .train import _load_pretrained_tok2vec @plac.annotations( @@ -36,6 +37,12 @@ from .. import util seed=("Seed for random number generators", "option", "s", int), n_iter=("Number of iterations to pretrain", "option", "i", int), n_save_every=("Save model every X batches.", "option", "se", int), + init_tok2vec=( + "Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.", + "option", + "t2v", + Path, + ), ) def pretrain( texts_loc, @@ -53,6 +60,7 @@ def pretrain( min_length=5, seed=0, n_save_every=None, + init_tok2vec=None, ): """ Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components, @@ -70,6 +78,9 @@ def pretrain( errors around this need some improvement. """ config = dict(locals()) + for key in config: + if isinstance(config[key], Path): + config[key] = str(config[key]) msg = Printer() util.fix_random_seed(seed) @@ -112,6 +123,10 @@ def pretrain( subword_features=True, # Set to False for Chinese etc ), ) + # Load in pre-trained weights + if init_tok2vec is not None: + components = _load_pretrained_tok2vec(nlp, init_tok2vec) + msg.text("Loaded pretrained tok2vec for: {}".format(components)) optimizer = create_default_optimizer(model.ops) tracker = ProgressTracker(frequency=10000) msg.divider("Pre-training tok2vec layer") diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md index 7788d7a8f..8b982309a 100644 --- a/website/docs/api/cli.md +++ b/website/docs/api/cli.md @@ -305,6 +305,7 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width] | `--n-iter`, `-i` | option | Number of iterations to pretrain. | | `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. | | `--n-save_every`, `-se` | option | Save model every X batches. | +| `--init-tok2vec`, `-t2v` 2.1 | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental.| | **CREATES** | weights | The pre-trained weights that can be used to initialize `spacy train`. | ### JSONL format for raw text {#pretrain-jsonl}