From 6a27a4f77c78b228325cf2e1211c8ef9ddee98b8 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal <honnibal+gh@gmail.com>
Date: Wed, 21 Feb 2018 21:02:41 +0100
Subject: [PATCH] Set accelerating batch size in CoNLL train script
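
The batch size now starts small and compounds toward a maximum, so the
model gets frequent updates early in training. The start, stop, and
compound values can be overridden through environment variables read by
spacy.util.env_opt.

A minimal sketch of the resulting schedule (an equivalent generator,
assuming spacy.util.compounding multiplies the current value by the
compound factor on each step and clips it at the stop value; the name
compounding_schedule is a hypothetical stand-in):

    def compounding_schedule(start, stop, compound):
        # Yields start, start*compound, start*compound**2, ...,
        # clipped so values never exceed stop (assumes start < stop).
        curr = float(start)
        while True:
            yield min(curr, stop)
            curr *= compound

    # With batch_from=1, batch_to=8, batch_compound=1.001, the batch
    # size grows from 1 toward 8 over roughly 2,000 updates.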

---
 examples/training/conllu.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index 867501844..fa4fefcea 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -218,13 +218,18 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
     n_train_words = sum(len(doc) for doc in docs)
     print(n_train_words)
     print("Begin training")
+    # Batch size starts at 1 and grows, so that we make updates quickly
+    # at the beginning of training.
+    batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
+                                         spacy.util.env_opt('batch_to', 8),
+                                         spacy.util.env_opt('batch_compound', 1.001))
     for i in range(10):
         with open(text_train_loc) as file_:
             docs = get_docs(nlp, file_.read())
         docs = docs[:len(golds)]
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             losses = {}
-            for batch in minibatch(list(zip(docs, golds)), size=1):
+            for batch in minibatch(list(zip(docs, golds)), size=batch_sizes):
                 if not batch:
                     continue
                 batch_docs, batch_gold = zip(*batch)