From 77af0a6bb48721f43fba2715191d0fe79867f0b7 Mon Sep 17 00:00:00 2001 From: Matthw Honnibal Date: Thu, 9 Jul 2020 14:50:20 +0200 Subject: [PATCH] Offer option of padding-sensitive batching --- spacy/cli/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index bda3c9ca2..2f1556beb 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -303,11 +303,19 @@ def create_train_batches(nlp, corpus, cfg): ) epoch = 0 + batch_strategy = cfg.get("batch_by", "sequences") while True: if len(train_examples) == 0: raise ValueError(Errors.E988) epoch += 1 - if cfg.get("batch_by_words", True): + if batch_strategy == "padded": + batches = util.minibatch_by_padded_size( + train_examples, + size=cfg["batch_size"], + buffer=256, + discard_oversize=cfg["discard_oversize"], + ) + elif batch_strategy == "words": batches = util.minibatch_by_words( train_examples, size=cfg["batch_size"], @@ -318,7 +326,7 @@ def create_train_batches(nlp, corpus, cfg): train_examples, size=cfg["batch_size"], ) - + # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop try: first = next(batches)