diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py
index 75663c03d..14855cb11 100644
--- a/spacy/cli/ud_train.py
+++ b/spacy/cli/ud_train.py
@@ -13,6 +13,7 @@ import spacy
 import spacy.util
 from ..tokens import Token, Doc
 from ..gold import GoldParse
+from ..util import compounding
 from ..syntax.nonproj import projectivize
 from ..matcher import Matcher
 from collections import defaultdict, Counter
@@ -36,7 +37,7 @@ lang.ja.Japanese.Defaults.use_janome = False
 random.seed(0)
 numpy.random.seed(0)
 
-def minibatch_by_words(items, size=5000):
+def minibatch_by_words(items, size):
     random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
@@ -368,9 +369,10 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
 
     optimizer = initialize_pipeline(nlp, docs, golds, config)
 
+    batch_sizes = compounding(config.batch_size // 10, config.batch_size, 1.001)
     for i in range(config.nr_epoch):
         docs = [nlp.make_doc(doc.text) for doc in docs]
-        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
+        batches = minibatch_by_words(list(zip(docs, golds)), size=batch_sizes)
         losses = {}
         n_train_words = sum(len(doc) for doc in docs)
        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
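
For reference, a minimal sketch of how the compounding schedule introduced above behaves; the
batch_size value of 5000 used here is illustrative only (in the patch it comes from
config.batch_size):

    from spacy.util import compounding

    batch_size = 5000
    # Yields an infinite series starting at batch_size // 10 and growing by a
    # factor of 1.001 per step, capped at the upper bound batch_size.
    batch_sizes = compounding(batch_size // 10, batch_size, 1.001)

    for _ in range(3):
        print(next(batch_sizes))  # roughly 500.0, 500.5, 501.0

Because minibatch_by_words already distinguishes int sizes (wrapped in itertools.repeat) from
iterator sizes, passing this generator presumably gives each successive minibatch a slightly
larger word budget instead of a fixed 5000-word cap for the whole run.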